In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from cropping_dataset import CustomImageDataset, DPImageDataset
from build_network import *
from sklearn.metrics import classification_report, roc_auc_score, RocCurveDisplay
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm
import os
import engine
import warnings
import timm
import sys
sys.path.append('/home/ubuntu/storage2/hyunjoong/cancer/DINO_knn')
import vision_transformer
warnings.filterwarnings('ignore')

# Process-level settings for this notebook. The CUDA_* variables are only
# honoured if they are set before CUDA is initialised, so this cell must run
# before anything touches the GPU.
os.environ.update({
    'MASTER_ADDR': 'localhost',     # address/port pair; presumably for torch.distributed env:// init
    'MASTER_PORT': '12345',
    'CUDA_LAUNCH_BLOCKING': "1",    # synchronous kernel launches (debug-friendly, slower)
    "CUDA_VISIBLE_DEVICES": "1",    # expose only physical GPU 1 to this process
})

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

## 1. Prepare the model

In [None]:
class LinearClassifier(nn.Module):
    """MLP head to train on top of frozen features.

    Layout: ``dim -> 2048 -> 256 -> num_labels``. Each hidden stage applies
    Linear, then Dropout(p=0.4), then GELU; the final layer returns raw logits.

    Args:
        dim: Number of features per sample once the input is flattened.
        num_labels: Number of output logits.
    """

    def __init__(self, dim, num_labels=1000):
        super().__init__()
        self.num_labels = num_labels

        # Hidden stage 1: dim -> 2048
        self.linear1 = nn.Linear(dim, 2048, bias=True)
        self.dropout1 = nn.Dropout(p=0.4)
        self.act1 = nn.GELU()

        # Hidden stage 2: 2048 -> 256
        self.linear2 = nn.Linear(2048, 256, bias=True)
        self.dropout2 = nn.Dropout(p=0.4)
        self.act2 = nn.GELU()

        # Output projection: 256 -> num_labels
        self.out = nn.Linear(256, num_labels, bias=True)

    def forward(self, x):
        """Flatten every sample to a vector and return the logits."""
        hidden = x.view(x.size(0), -1)

        stages = (
            (self.linear1, self.dropout1, self.act1),
            (self.linear2, self.dropout2, self.act2),
        )
        # Dropout is applied before the activation, in that order.
        for linear, dropout, activation in stages:
            hidden = activation(dropout(linear(hidden)))

        return self.out(hidden)

In [None]:
class build_model(nn.Module):
    """Backbone followed by a single linear classification head.

    Args:
        vit: Backbone module. Its output is fed straight into the head, so the
            last dimension of that output must equal ``embed_dim``.
        embed_dim: Feature size produced by ``vit``. Defaults to 768, the
            value that used to be hard-coded here.
        num_classes: Number of output logits. Defaults to 2.
        freeze_backbone: When True, gradients are disabled for every backbone
            parameter so that only the head is trained. Defaults to False,
            which matches the previous behaviour.
    """

    def __init__(self, vit, embed_dim=768, num_classes=2, freeze_backbone=False):
        super().__init__()
        self.vit = vit
        if freeze_backbone:
            for param in self.vit.parameters():
                # The attribute is `requires_grad`; assigning to a misspelt
                # name would just create an unused attribute and freeze nothing.
                param.requires_grad = False
        self.classifier = nn.Linear(embed_dim, num_classes, bias=True)

    def forward(self, x):
        """Run the backbone, then map its features to class logits."""
        features = self.vit(x)
        return self.classifier(features)


### Custom Model

In [None]:
# Load the pickled, wrapped model (it exposes `.module.backbone`).
# `map_location` remaps the stored tensors onto the device chosen in the setup
# cell, so a file saved from a GPU index that is hidden by CUDA_VISIBLE_DEVICES
# still loads.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
# Recent PyTorch releases default to weights_only=True, which rejects whole-model
# pickles like this one; confirm against the installed version.
model = torch.load('./best_model.pth', map_location=device)

# Overwrite the weights with the 'student' entry of the training checkpoint.
checkpoint = torch.load('./checkpoint.pth', map_location='cpu')
load_result = model.load_state_dict(checkpoint['student'], strict=False)
# strict=False ignores key mismatches silently, so report them: a large count
# here means most of the checkpoint was NOT actually loaded.
print(f'missing keys: {len(load_result.missing_keys)} | '
      f'unexpected keys: {len(load_result.unexpected_keys)}')

# Keep only the backbone and attach a fresh 2-class head.
model = model.module.backbone
embed_dim = model.embed_dim
model.head = nn.Linear(embed_dim, 2)
# The new head is created on the CPU; put it on the same device as the backbone.
model = model.to(device)
model

## 2. Prepare the datasets

In [None]:
batch_size = 16
num_workers = 5

# Per-channel normalisation statistics (the standard ImageNet mean / std).
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)


def _eval_pipeline():
    """Deterministic preprocessing shared by the validation and test splits."""
    # NOTE(review): an int size makes Resize rescale the shorter edge only, so
    # batching relies on the inputs sharing one aspect ratio; confirm.
    return transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])


val_transform = _eval_pipeline()

# Training adds random flips and colour augmentation before tensor conversion.
color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

test_transform = _eval_pipeline()



In [None]:

# Validation split. It is iterated in a fixed order (no shuffling) so that the
# per-epoch validation metrics are reproducible and comparable across epochs.
dataset_val = DPImageDataset('../val3_df.csv', transform=val_transform)
val_loader = DataLoader(
    dataset_val,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
    pin_memory=True,
)

# Training split. Reshuffled every epoch.
dataset_train = DPImageDataset('../train3_df.csv', transform=train_transform)
train_loader = DataLoader(
    dataset_train,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    pin_memory=True,
)

# Held-out test split. DataLoader does not shuffle by default.
test_dataset = DPImageDataset('../test3_df.csv', transform=test_transform)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
)

## 3. Training the model

In [None]:
lr = 0.001
epochs = 100

# SGD with momentum at a fixed base learning rate; weight decay is disabled.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0)

# Lower the learning rate when the monitored value has not decreased for
# 3 scheduler steps, never going below 1e-8. Which value is monitored is
# decided inside engine.train; presumably the validation loss, confirm.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=3, min_lr=1e-8,
)

loss_fn = torch.nn.CrossEntropyLoss()

# The validation loader is handed to engine.train under its `test_dataloader`
# argument; the real test split is only used in the evaluation cell.
train_kwargs = dict(
    model=model,
    train_dataloader=train_loader,
    test_dataloader=val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    epochs=epochs,
    output_dir='./classifier_checkpoint',
    device=device,
    parallel=True,
)
model, hist = engine.train(**train_kwargs)

In [None]:
# Create the output directory first: torch.save does not create parent
# directories, and failing here would lose the model after a full training run.
save_path = './vitb8_dataset2048/vitb8_classifier.pt'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Saves the whole pickled module (not just a state_dict).
torch.save(model, save_path)

In [None]:
# Presumably imported so the classes referenced by the pickled model can be
# resolved when it is unpickled; confirm before removing.
from VisionTransformer import vision_transformer

# NOTE(review): this path differs from the file written by the save cell
# ('./vitb8_dataset2048/vitb8_classifier.pt'); confirm which model is meant
# to be evaluated.
# map_location puts the weights on the device that the evaluation cell moves
# its batches to, wherever the file was originally saved from.
model = torch.load('./vits16_data2048/checkpoint_model.pt', map_location=device)
model

In [None]:
from sklearn.metrics import classification_report

loss_fn = torch.nn.CrossEntropyLoss()

# Batches are moved to `device` below, so the model has to live there as well.
model.to(device)
model.eval()

total_loss = 0.0   # sum of per-sample losses
n_correct = 0      # number of correctly classified samples
n_samples = 0      # number of samples seen
preds = []         # predicted class index per sample
y_labels = []      # ground-truth class index per sample

with torch.no_grad():
    for data, target in tqdm(test_dataloader):
        data, target = data.to(device), target.to(device)
        output = model(data)

        batch_n = target.size(0)
        # CrossEntropyLoss returns the batch mean; weight it by the batch size
        # so a smaller final batch does not skew the overall average.
        total_loss += loss_fn(output, target).item() * batch_n

        # argmax over logits equals argmax over softmax(logits).
        y_pred_class = output.argmax(dim=1)
        n_correct += (y_pred_class == target).sum().item()
        n_samples += batch_n

        preds.extend(y_pred_class.cpu().tolist())
        y_labels.extend(target.cpu().tolist())

# Per-sample averages; max(..., 1) avoids dividing by zero on an empty loader.
test_loss = total_loss / max(n_samples, 1)
test_acc = n_correct / max(n_samples, 1)
print(f'test_loss : {test_loss} | test_acc : {test_acc}')

# scikit-learn expects (y_true, y_pred); swapping them exchanges the reported
# precision and recall.
report = classification_report(y_labels, preds)
print(report)

In [None]:
from sklearn.metrics import recall_score, f1_score, precision_score

# scikit-learn metrics take (y_true, y_pred). With the arguments reversed the
# value printed as precision is really the recall, and vice versa.
print(f"precision : {precision_score(y_labels, preds)}")
print(f"recall : {recall_score(y_labels, preds)}")
print(f"f1_score : {f1_score(y_labels, preds)}")