In [None]:
from torchvision import transforms
from PIL import Image 
import numpy as np
import os
import json
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split

#from torch.optim.lr_scheduler import CosineAnnealingLR

#from transformers import CLIPProcessor, CLIPModel
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
#from transformers import AutoModel, AutoTokenizer

from tqdm import tqdm
from transformers import BertModel, BertTokenizer
import re

In [None]:
# GPU clean-up: ask PyTorch's CUDA caching allocator to release cached memory
# blocks that are currently unused (useful when re-running the notebook in
# the same kernel).
torch.cuda.empty_cache()
# Optional (disabled): reset the recorded peak-memory statistic.
#torch.cuda.reset_max_memory_allocated()

In [None]:
class DataProcessor:
    """Collect single-disease samples from a SLAKE-style dataset folder.

    Layout read by ``load_data``::

        <folder_path>/imgs/<folder>/detection.json
        <folder_path>/imgs/<folder>/question.json

    Attributes:
        folder_path: Dataset root directory.
        disease_list: Disease names that are accepted as class labels.
        folders_with_diseases_labels: ``{folder name: disease label}`` for
            every folder whose detection file names exactly one listed disease.
        folder_name_with_diseases: The same folder names, as a list.
        label_counts: ``Counter`` of labels per folder; set by
            ``delete_folders``.
        label_to_index: ``{disease label: class index}``; set by
            ``get_training_data``.
        data: List of sample dicts with keys ``image_path``, ``text`` and
            ``label``.
    """

    def __init__(self, folder_path, disease_list):
        self.folder_path = folder_path
        self.disease_list = disease_list
        self.folders_with_diseases_labels = {}
        self.folder_name_with_diseases = []
        self.label_counts = None
        self.label_to_index = None
        self.data = []

    def load_data(self):
        """Scan every image folder and fill ``self.data`` with usable samples.

        A folder is used only when its ``detection.json`` mentions exactly one
        disease from ``disease_list``. Its sample text is the concatenation of
        all English question/answer pairs from ``question.json``, excluding
        questions containing "What disease" (they would leak the label).
        Folders without a ``detection.json`` are skipped; folders without a
        (non-empty) ``question.json`` are registered but yield no sample.
        """
        imgs_folder = os.path.join(self.folder_path, "imgs")
        # Sorted so the sample order (and therefore the seeded train/val split)
        # does not depend on the file system's listing order.
        subdirectories = sorted(
            d for d in os.listdir(imgs_folder)
            if os.path.isdir(os.path.join(imgs_folder, d))
        )

        for folder in subdirectories:
            # detection file: the report of the image
            detection_file_path = os.path.join(imgs_folder, folder, 'detection.json')
            if not os.path.exists(detection_file_path):
                continue
            with open(detection_file_path, 'r', encoding='utf-8') as detection_file:
                detection_data = json.load(detection_file)

            disease_labels = [
                label
                for item in detection_data
                for label in item.keys()
                if label in self.disease_list
            ]

            # Keep only images annotated with exactly one of the listed diseases.
            if len(disease_labels) != 1:
                continue
            label = disease_labels[0]
            self.folders_with_diseases_labels[folder] = label
            self.folder_name_with_diseases.append(folder)

            # question file
            question_file_path = os.path.join(imgs_folder, folder, "question.json")
            if not os.path.exists(question_file_path):
                continue
            # The question files contain non-English entries, so the encoding
            # is fixed instead of relying on the platform default.
            with open(question_file_path, "r", encoding="utf-8") as question_file:
                question_entries = json.load(question_file)
            if not question_entries:
                # Nothing to read the image name from.
                continue

            img_name = question_entries[0].get("img_name", "")
            english_questions = [
                item for item in question_entries
                if "question" in item and item.get("q_lang", "") == "en"
            ]

            # Concatenate all questions and answers into a single string,
            # removing questions that ask for the disease itself.
            all_qa = " ".join(
                f"Q: {item['question']} A: {item['answer']}"
                for item in english_questions
                if "What disease" not in item['question']
            )

            self.data.append({
                "image_path": os.path.join(imgs_folder, img_name),
                "text": all_qa,
                "label": label
            })

    def delete_folders(self):
        """Drop every class that occurs in three or fewer folders.

        Nothing is deleted on disk: the affected folders are removed from
        ``folders_with_diseases_labels`` / ``folder_name_with_diseases`` and
        their samples from ``data``. Samples are matched by image path, which
        assumes the image of a folder is ``<folder>/source.jpg``.
        """
        # frequency of each label over the registered folders
        self.label_counts = Counter(self.folders_with_diseases_labels.values())

        folders_to_delete = [
            folder_name
            for folder_name, label in self.folders_with_diseases_labels.items()
            if self.label_counts[label] <= 3
        ]

        # Build the paths the same way load_data does, from the configured
        # dataset root rather than a hard-coded 'Slake1.0' prefix.
        imgs_folder = os.path.join(self.folder_path, "imgs")
        paths_to_drop = {
            os.path.join(imgs_folder, f"{folder}/source.jpg")
            for folder in folders_to_delete
        }
        self.data = [item for item in self.data if item["image_path"] not in paths_to_drop]

        dropped = set(folders_to_delete)
        for folder_name in folders_to_delete:
            del self.folders_with_diseases_labels[folder_name]
        self.folder_name_with_diseases[:] = [
            name for name in self.folder_name_with_diseases if name not in dropped
        ]

    def get_training_data(self):
        """Replace string labels by class indices and split the samples.

        The label-to-index mapping is built from the sorted label names, so it
        is identical on every run, and is kept in ``self.label_to_index``.
        ``self.data`` is overwritten with the index-labelled samples, so this
        method is meant to be called once per instance.

        Returns:
            ``(train_data, val_data)``: an 80% / 20% shuffled split of
            ``self.data`` with a fixed random seed.
        """
        # Convert labels to indices (sorted: iteration order of a set of
        # strings is not stable between interpreter runs).
        labels = sorted(set(self.folders_with_diseases_labels.values()))
        self.label_to_index = {label: idx for idx, label in enumerate(labels)}
        self.data = [
            {
                "image_path": item["image_path"],
                "text": item["text"],
                "label": self.label_to_index[item["label"]],
            }
            for item in self.data
        ]

        # 80% for training and 20% for validation
        train_data, val_data = train_test_split(self.data, test_size=0.2, random_state=42, shuffle=True)
        return train_data, val_data


# Dataset root and the disease names accepted as class labels. A folder is
# only used when exactly one of these names appears in its detection.json.
folder_path = 'Slake1.0'
disease_list = ['Pneumothorax', 'Pneumonia', 'Effusion','Cardiomegaly','Lung Cancer']
data_processor = DataProcessor(folder_path, disease_list)
data_processor.load_data()
# Optional (disabled): drop classes that occur in three or fewer folders.
#data_processor.delete_folders()
# Lists of sample dicts (image_path, text, integer label): 80% train / 20% validation.
train_data, val_data = data_processor.get_training_data()
#training_data = data_processor.get_training_data()


In [None]:
# Preprocessing of texts 
def get_tokens(text, tokenizer, max_length=None):
    """Tokenize ``text`` and wrap it in BERT's ``[CLS]`` / ``[SEP]`` markers.

    Args:
        text: Raw input string.
        tokenizer: Object providing ``tokenize(text) -> list[str]``.
        max_length: Maximum number of tokens to return. When omitted, the
            notebook-level ``max_length`` variable is used (legacy call style).

    Returns:
        ``(tokens, length)`` where ``tokens`` holds at most ``max_length``
        entries and ``length`` is the token count *before* truncation.
    """
    if max_length is None:
        # Backward compatibility: older callers relied on the global defined
        # in a later notebook cell.
        max_length = globals()["max_length"]

    tokens = ["[CLS]"] + tokenizer.tokenize(text) + ["[SEP]"]
    length = len(tokens)
    if length > max_length:
        if max_length >= 2:
            # Truncate the text but keep the closing [SEP]; a plain slice
            # would cut the separator off.
            tokens = tokens[:max_length - 1] + ["[SEP]"]
        else:
            tokens = tokens[:max_length]
    return tokens, length


def get_masks(text, tokenizer, max_length):
    """Attention mask of length ``max_length``: 1 for real tokens, 0 for padding."""
    tokens, _ = get_tokens(text, tokenizer, max_length)
    return np.asarray([1] * len(tokens) + [0] * (max_length - len(tokens)))


def get_segments(text, tokenizer, max_length):
    """Segment ids of length ``max_length``.

    Tokens up to and including the first ``[SEP]`` get 0, tokens after it get
    1; padding positions get 0.
    """
    tokens, _ = get_tokens(text, tokenizer, max_length)
    segments = []
    current_segment_id = 0
    for token in tokens:
        segments.append(current_segment_id)
        if token == "[SEP]":
            current_segment_id = 1
    return np.asarray(segments + [0] * (max_length - len(tokens)))


def get_ids(text, tokenizer, max_length):
    """Vocabulary ids of the tokens, right-padded with 0 up to ``max_length``.

    ``tokenizer`` must provide ``convert_tokens_to_ids(tokens) -> list[int]``.
    Padding uses id 0 (presumably the ``[PAD]`` id of the vocabulary in use;
    verify against the tokenizer).
    """
    tokens, _ = get_tokens(text, tokenizer, max_length)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    return np.asarray(token_ids + [0] * (max_length - len(tokens)))

In [None]:
class CustomDataset(Dataset):
    """Dataset yielding one image + tokenized text + label per sample.

    Args:
        data_processor: Indexable collection of sample dicts with the keys
            ``image_path``, ``text`` and ``label`` (despite the name, this is
            the list of samples, e.g. the train or validation split).
        tokenizer: Tokenizer handed to the text helper functions.
        max_length: Fixed length of the token id / mask / segment arrays.
        transform: Optional callable applied to the PIL image.
        is_train: When true, random flips and a random rotation are applied
            to the image before ``transform``.
    """

    def __init__(self, data_processor, tokenizer, max_length, transform=None, is_train=True):
        self.data_processor = data_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.transform = transform
        self.is_train = is_train

    def __len__(self):
        """Number of samples."""
        return len(self.data_processor)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Keys: ``input_word_ids``, ``input_mask``, ``segment_ids`` (long
        tensors of length ``max_length``), ``image`` (the transformed image)
        and ``label`` (taken unchanged from the sample).
        """
        sample = self.data_processor[idx]

        # Open inside a context manager so the file handle is closed right
        # away; convert() returns an independent in-memory RGB copy.
        with Image.open(sample["image_path"]) as raw_image:
            image = raw_image.convert('RGB')

        if self.is_train:
            # Augmentation: each flip with probability 0.5, then a rotation by
            # an angle drawn uniformly from [-30, 30) degrees.
            if np.random.rand() > 0.5:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)

            if np.random.rand() > 0.5:
                image = image.transpose(Image.FLIP_TOP_BOTTOM)

            angle = np.random.uniform(-30, 30)
            image = image.rotate(angle)

        if self.transform:
            image = self.transform(image)

        text = sample["text"]
        label = sample["label"]

        # Preprocess text using the helper functions (each returns a numpy
        # array of length max_length).
        masks = get_masks(text, self.tokenizer, self.max_length)
        segments = get_segments(text, self.tokenizer, self.max_length)
        ids = get_ids(text, self.tokenizer, self.max_length)

        # numpy arrays to PyTorch tensors
        ids = torch.tensor(ids, dtype=torch.long)
        masks = torch.tensor(masks, dtype=torch.long)
        segments = torch.tensor(segments, dtype=torch.long)

        return {"input_word_ids": ids, "input_mask": masks, "segment_ids": segments, "image": image, "label": label}


In [None]:
# Image preprocessing applied to every sample: resize to 224x224, convert to
# a tensor, then normalise each channel (these mean/std values are the ones
# commonly used with ImageNet-pretrained torchvision models).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

#tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Tokenizer for the question/answer text.
# NOTE(review): the BERT encoder used by the model must share this
# tokenizer's vocabulary -- confirm both load the same checkpoint.
tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

# Fixed length of the token id / mask / segment arrays built per sample.
max_length = 128  

# Augmentation (random flips / rotation) is enabled for the training set only.
train_dataset = CustomDataset(train_data, tokenizer, max_length, transform=transform, is_train=True)

val_dataset = CustomDataset(val_data, tokenizer, max_length, transform=transform, is_train=False)


# Batches of 8 samples; only the training data is reshuffled every epoch.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)


In [None]:
# Images Model with AlexNet
class ImageModel(nn.Module):
    """Image branch: pretrained AlexNet convolutional features -> 128-d vector.

    Pipeline in ``forward``: AlexNet ``features`` -> adaptive average pooling
    to 8x8 -> dropout (p=0.4) -> flatten -> linear layer with 128 outputs.

    Args:
        num_classes: Stored on the instance; not used by the layers here.
        input_shape: Accepted for interface compatibility; not used.
    """

    def __init__(self, num_classes, input_shape=(3, 224, 224)):
        super().__init__()

        backbone = models.alexnet(pretrained=True)
        # Drop the last two layers of AlexNet's classifier head. Note that
        # forward() only runs `features`, so this head is never evaluated.
        kept_head_layers = list(backbone.classifier.children())[:-2]
        backbone.classifier = nn.Sequential(*kept_head_layers)
        self.alexnet = backbone

        self.avg_pooling = nn.AdaptiveAvgPool2d((8, 8))
        self.dropout = nn.Dropout(0.4)
        self.flatten = nn.Flatten()
        # 256 feature channels on the pooled 8x8 grid.
        self.fc = nn.Linear(256 * 8 * 8, 128)
        self.num_classes = num_classes

    def forward(self, x):
        """Map an image batch to a ``(batch, 128)`` feature tensor."""
        feature_map = self.alexnet.features(x)
        pooled = self.dropout(self.avg_pooling(feature_map))
        return self.fc(self.flatten(pooled))

# text model: Bert + LSTM 
class BertLSTMModel(nn.Module):
    """Text branch: BERT ``[CLS]`` embedding passed through a single-step LSTM.

    Args:
        bert_model: Name or path handed to ``BertModel.from_pretrained``.
        max_length: Stored on the instance; not used by the layers here.
        lstm_hidden_size: Size of the LSTM hidden state, i.e. of the output
            feature vector.
    """

    def __init__(self, bert_model, max_length, lstm_hidden_size):
        super(BertLSTMModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        # 768 is the width of the BERT hidden states fed into the LSTM.
        self.lstm = nn.LSTM(input_size=768, hidden_size=lstm_hidden_size, batch_first=True)
        self.max_length = max_length

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return a ``(batch, lstm_hidden_size)`` text feature tensor."""
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # [CLS] token embedding (first position) for the LSTM
        lstm_input = outputs.last_hidden_state[:, 0, :]
        # add a time dimension of length 1: (batch, 1, 768)
        lstm_output, _ = self.lstm(lstm_input.unsqueeze(1))
        # Remove only the time dimension. A bare squeeze() would also remove
        # the batch dimension when the batch holds a single sample and break
        # the later concatenation along dim=1.
        return lstm_output.squeeze(1)

# Fusion Model
class FusionModel(nn.Module):
    """Late-fusion classifier: concatenated text and image features -> logits.

    Args:
        num_classes: Number of output classes.
        lstm_hidden_size: Size of the text feature vector.
        dropout_rate: Dropout probability applied after the first linear layer.
        bert_model: Checkpoint for the text encoder. It has to be the
            checkpoint the tokenizer was loaded from, otherwise token ids
            index the wrong embeddings; the default matches the BioBERT
            tokenizer used in this notebook. Pass ``"bert-base-uncased"``
            (together with the matching tokenizer) for the previous encoder.
        max_length: Token sequence length forwarded to the text model.
    """

    def __init__(self, num_classes, lstm_hidden_size=128, dropout_rate=0.4,
                 bert_model="dmis-lab/biobert-v1.1", max_length=128):
        super(FusionModel, self).__init__()
        self.text_model = BertLSTMModel(bert_model, max_length, lstm_hidden_size)
        self.image_model = ImageModel(num_classes)
        # Width of the fused vector = text features + image features, derived
        # from the branches instead of a hard-coded 256.
        fused_size = lstm_hidden_size + self.image_model.fc.out_features
        self.fc1 = nn.Linear(fused_size, 128)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, input_ids, attention_mask, token_type_ids, image_input):
        """Return unnormalised class scores of shape ``(batch, num_classes)``."""
        text_output = self.text_model(input_ids, attention_mask, token_type_ids)
        image_output = self.image_model(image_input)
        fused_input = torch.cat([text_output, image_output], dim=1)
        # NOTE(review): there is no non-linearity between fc1 and fc2.
        x = self.fc1(fused_input)
        x = self.dropout1(x)
        x = self.fc2(x)
        return x


In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyper-parameters.
max_length = 128
lstm_hidden_size = 128
# One class per entry of disease_list.
num_classes = 5
# NOTE(review): batch_size is not used below; the data loaders were created
# with their own batch size.
batch_size = 4
epochs = 42

model = FusionModel(num_classes, lstm_hidden_size)
model.to(device)

# Cross-entropy over the class scores; plain SGD with L2 weight decay.
criterion = nn.CrossEntropyLoss()
weight_decay = 5 * 1e-2

optimizer = optim.SGD(model.parameters(), lr=1e-3, weight_decay=weight_decay)

# Per-epoch history, plotted after training.
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

# The weights with the best validation accuracy so far are written to this file.
best_val_accuracy = 0.0
best_model_path = "fusion_model.pth"


for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    total_loss = 0.0
    correct_predictions = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
        
        input_ids = batch["input_word_ids"].to(device)
        attention_mask = batch["input_mask"].to(device)
        token_type_ids = batch["segment_ids"].to(device)
        image_input = batch["image"].to(device)
        labels = batch["label"].to(device)

        outputs = model(input_ids, attention_mask, token_type_ids, image_input)
        loss = criterion(outputs, labels)

        # Clear old gradients, back-propagate, update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Predicted class = index of the highest score.
        _, predicted = torch.max(outputs, 1)
        correct_predictions += (predicted == labels).sum().item()

    # Loss is averaged over batches, accuracy over samples.
    avg_loss = total_loss / len(train_loader)
    accuracy = correct_predictions / len(train_dataset)

    print(f"Epoch {epoch + 1}/{epochs}:")
    print(f"Train Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

    train_losses.append(avg_loss)
    train_accuracies.append(accuracy)

    # ---- validation pass (evaluation mode, no gradient tracking) ----
    model.eval()
    total_val_loss = 0.0
    correct_val_predictions = 0

    with torch.no_grad():
        for val_batch in tqdm(val_loader, desc=f"Validation Epoch {epoch + 1}/{epochs}"):
            
            val_input_ids = val_batch["input_word_ids"].to(device)
            val_attention_mask = val_batch["input_mask"].to(device)
            val_token_type_ids = val_batch["segment_ids"].to(device)
            val_image_input = val_batch["image"].to(device)
            val_labels = val_batch["label"].to(device)

            val_outputs = model(val_input_ids, val_attention_mask, val_token_type_ids, val_image_input)
            val_loss = criterion(val_outputs, val_labels)

            total_val_loss += val_loss.item()
            _, val_predicted = torch.max(val_outputs, 1)
            correct_val_predictions += (val_predicted == val_labels).sum().item()

    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = correct_val_predictions / len(val_dataset)
    print(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}")

    # Save the model weights whenever the validation accuracy strictly improves.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), best_model_path)

    val_losses.append(avg_val_loss)
    val_accuracies.append(val_accuracy)


In [None]:
# Plot the per-epoch training history: loss on the left, accuracy on the
# right, then write the figure to disk.
fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: loss curves
loss_ax.plot(train_losses, label='Training Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()

# Right panel: accuracy curves
accuracy_ax.plot(train_accuracies, label='Training Accuracy')
accuracy_ax.plot(val_accuracies, label='Validation Accuracy')
accuracy_ax.set_xlabel('Epochs')
accuracy_ax.set_ylabel('Accuracy')
accuracy_ax.set_title('Training and Validation Accuracy')
accuracy_ax.legend()

fig.tight_layout()

# Output location of the image file.
# NOTE(review): this is an absolute, machine-specific directory; it has to
# exist for the save to succeed.
save_path = '/storage/homefs/zh21i037/'
filename = 'losses and accuracies.png'
save_filename = os.path.join(save_path, filename)

fig.savefig(save_filename)
# Release the figure once it has been written.
plt.close(fig)