In [None]:
import os, gc, cv2, math, copy, time, random
import pickle

import numpy as np, pandas as pd
from collections import defaultdict

import torch, torch.nn as nn, torch.optim as optim
from torch.optim import lr_scheduler
from torch.cuda import amp
import torch.backends.cudnn as cudnn
import threading


import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import f1_score,roc_auc_score
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.metrics import recall_score

from tqdm.notebook import tqdm
from tqdm import tqdm

import torch.nn.functional as F
import timm
import ast

from torch.utils.data import Dataset, DataLoader

In [None]:
# Inference-time settings shared by the cells below.
CONFIG = {
    # Side length (in pixels) that every slice is resized to by `data_transforms`.
    "img_size": 384,
    # NOTE(review): not read by any cell visible in this notebook — confirm before relying on it.
    "valid_batch_size": 1,
    # Prefer the GPU, but fall back to the CPU so the notebook still runs
    # end-to-end on a machine without CUDA.
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
}

In [None]:
# Per-slice preprocessing pipeline: resize to a CONFIG['img_size'] square,
# normalize with the library's default statistics (Normalize is called with
# no arguments), then hand the result to ToTensorV2.
data_transforms = A.Compose(
    [
        A.Resize(height=CONFIG['img_size'], width=CONFIG['img_size']),
        A.Normalize(),
        ToTensorV2(),
    ],
    p=1.,
)

In [None]:
# Folder holding the test index CSV; later cells also read the model
# checkpoint from it and write the prediction CSVs into it.
files_dir = '/home/Data/rmf3mc/Challenge/Challenge1_Testset/files'

# Index of the test set. ChallengeDataset reads its 'Path' and 'Range' columns.
test_set = pd.read_csv(
    os.path.join(files_dir, 'la-test_path_range.csv'),
)

In [None]:
class ChallengeDataset(Dataset):
    """Dataset yielding, per row of ``df``, a stack of slices sampled from one folder.

    Args:
        df: DataFrame with a ``'Path'`` column (folder containing ``<index>.jpg``
            files) and a ``'Range'`` column holding the string form of a
            two-element sequence; element 0 is the first index to sample and
            element 1 the exclusive upper bound.
        transforms: Optional callable invoked as ``transforms(image=rgb)`` that
            returns a mapping with an ``'image'`` entry. When ``None`` the raw
            RGB array is converted to a channel-first tensor instead.
        batch_size: Number of slice indices sampled (evenly spaced) per item.
            Fewer slices are returned when some files cannot be read.
    """

    def __init__(self, df, transforms=None, batch_size=16):
        self.img_paths = df['Path'].tolist()
        self.transforms = transforms
        # 'Range' cells are strings; literal_eval parses them without eval().
        self.ranges = df['Range'].apply(ast.literal_eval).tolist()
        self.batch_size = batch_size

    def __len__(self):
        return len(self.img_paths)

    def _read_rgb(self, file_path):
        """Read one image as an RGB array, or return None when it cannot be read."""
        img = cv2.imread(file_path)
        if img is None:
            return None
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _to_tensor(self, img):
        """Apply the configured transforms (if any) and return a tensor."""
        if self.transforms is None:
            # No pipeline supplied: HWC array -> CHW tensor so slices can be stacked.
            return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
        transformed = self.transforms(image=img)['image']
        # as_tensor avoids the copy (and the UserWarning) that torch.tensor()
        # triggers when the pipeline already produced a tensor.
        return torch.as_tensor(transformed)

    def __getitem__(self, idx):
        """Return ``{'image': (n_slices, ...) tensor, 'path': folder}`` for row ``idx``.

        Raises:
            RuntimeError: if none of the sampled slice files could be read.
        """
        range_list = self.ranges[idx]
        # Images are read from the raw folder rather than the preprocessed one.
        img_path = self.img_paths[idx].replace('test_set_preprocessed', 'test_set')

        # Evenly spaced slice indices in [range_list[0], range_list[1]).
        sampled_indices = np.round(
            np.linspace(range_list[0], range_list[1], self.batch_size, endpoint=False)
        ).astype(int)

        images = []
        for slice_idx in sampled_indices:
            img = self._read_rgb(os.path.join(img_path, f"{slice_idx}.jpg"))
            if img is None:
                # Unreadable or missing slice: skip it, as before.
                continue
            images.append(self._to_tensor(img))

        if not images:
            # torch.stack([]) would raise an opaque RuntimeError; say which folder failed.
            raise RuntimeError(
                f"No readable slices for '{img_path}' in range {range_list}"
            )

        return {
            'image': torch.stack(images), 'path': img_path,
        }

In [None]:
# Each dataset item is one folder with up to 40 evenly sampled slices.
test_dataset = ChallengeDataset(
    df=test_set,
    transforms=data_transforms,
    batch_size=40,
)
# One item per step, in file order.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
)

In [None]:
class eca_nfnet_l0(nn.Module):
    """timm ``eca_nfnet_l0`` backbone with a spatial-attention head and one output logit.

    A forward hook on the backbone's ``final_act`` layer captures its feature
    map. The map is re-weighted by a 1x1-conv attention computed from its
    channel-wise mean and max, average-pooled, and fed to a linear classifier.

    Args:
        pretrained: Forwarded to ``timm.create_model``. Pass ``False`` to skip
            fetching the hub weights, e.g. when a full checkpoint is loaded
            with ``load_state_dict`` right afterwards. Defaults to ``True``.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        self.model = timm.create_model("hf_hub:timm/eca_nfnet_l0", pretrained=pretrained)
        self.classifier = nn.Linear(self.model.head.fc.in_features, 1, bias=True)

        # Maps the 2-channel (mean, max) descriptor to a 1-channel attention map.
        self.attention = nn.Conv2d(2, 1, kernel_size=1, bias=True)

        # Hook outputs keyed by the id of the thread that ran the forward pass.
        self.features = {}
        self.model.final_act.register_forward_hook(self.get_features)

    def set_features(self, features):
        """Replace the whole hook-output store."""
        self.features = features

    def get_features(self, module, input, output):
        """Forward hook: store ``output`` under the calling thread's id."""
        self.features[threading.get_ident()] = output

    def getAttFeats(self, att_map, features):
        """Blend the features equally with their attention-weighted version."""
        return 0.5 * features + 0.5 * (att_map * features)

    def forward(self, x):
        """Return the classifier logits for ``x``, one per input image."""
        # The backbone's own output is discarded; it runs only so the hook fires.
        self.model(x)

        # pop() rather than index: otherwise the store keeps the last feature
        # map of every thread alive long after the forward pass has finished.
        features = self.features.pop(threading.get_ident())

        channel_stats = torch.cat(
            (
                torch.mean(features, dim=1, keepdim=True),
                torch.max(features, dim=1, keepdim=True)[0],
            ),
            dim=1,
        )
        fg_att = torch.sigmoid(self.attention(channel_stats))
        features = self.getAttFeats(fg_att, features)

        out = F.adaptive_avg_pool2d(features, (1, 1))
        out = torch.flatten(out, 1)
        return self.classifier(out)
    
    
# Naming strings; no later cell visible in this notebook reads them.
bin_save_path = "/mnt/saved_models/14batch_eachof12_combing_challenge1_2training_challene2validation"
# CONFIG defines no 'epochs' entry in this notebook, so indexing it directly
# raised KeyError on a fresh kernel; fall back to a placeholder instead.
job_name = f"epoch:{CONFIG.get('epochs', 'NA')}_ECA_Attention_{CONFIG['img_size']}"
model = eca_nfnet_l0()
print(model)

In [None]:
model_path = os.path.join(files_dir, 'model.pth')
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host;
# the model itself is moved to CONFIG['device'] in the next cell.
state_dict = torch.load(model_path, map_location='cpu')

model.load_state_dict(state_dict)

In [None]:
# Move the network's parameters and buffers onto the configured device.
model = model.to(device=CONFIG['device'])

In [None]:
# Module-level accumulators filled by valid_one_epoch: folder names of the
# scans predicted negative / positive. They are appended to on every call,
# so re-run this cell before calling the function a second time.
non_covid = []
covid = []


@torch.inference_mode()
def valid_one_epoch(model, dataloader, device, epoch):
    """Classify every scan yielded by ``dataloader`` and record its folder name.

    Each batch is expected to be a dict with ``'image'`` of shape
    ``(n_scans, n_slices, C, H, W)`` and ``'path'``, a sequence of ``n_scans``
    folder paths. The per-slice sigmoid outputs are averaged per scan; a mean
    below 0.5 appends the folder name to ``non_covid``, otherwise to ``covid``.

    Args:
        model: Network returning one logit per slice; put into eval mode here.
        dataloader: Iterable of batches as described above (must support ``len``).
        device: Device the slices are moved to before the forward pass.
        epoch: Unused; kept for signature compatibility.

    Returns:
        None. Results are accumulated in the module-level lists.
    """
    model.eval()

    for data in tqdm(dataloader, total=len(dataloader)):
        n_scans, _, c, h, w = data['image'].size()
        # Flatten (scan, slice) so the network sees one image per row.
        images = data['image'].reshape(-1, c, h, w).to(device, dtype=torch.float)

        outputs = model(images)

        # Back to one row per scan, so a DataLoader batch size above 1 no
        # longer averages different scans together under the first path only.
        probs = torch.sigmoid(outputs).cpu().numpy().reshape(n_scans, -1)

        for path, scan_probs in zip(data['path'], probs):
            # rstrip guards against a trailing slash producing an empty name.
            file_name = path.rstrip("/").split("/")[-1]
            if scan_probs.mean() < 0.5:
                non_covid.append(file_name)
            else:
                covid.append(file_name)


In [None]:
# Run inference over the whole test loader; results land in `covid` / `non_covid`.
valid_one_epoch(
    model=model,
    dataloader=test_loader,
    device=CONFIG['device'],
    epoch=1,
)

In [None]:
import csv

# Write the scans predicted negative, one folder name per row.
with open(os.path.join(files_dir, 'non-covid.csv'), mode='w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerows([name] for name in non_covid)

In [None]:
# Write the scans predicted positive, one folder name per row.
# Relies on `csv` having been imported by the previous cell.
with open(os.path.join(files_dir, 'covid.csv'), mode='w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerows([name] for name in covid)