In [34]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch
import torch.nn as nn
from modules import SAB, PMA
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import os
import pandas as pd

num_classes = 4


class SmallDeepSet(nn.Module):
    def __init__(self, pool="max"):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Linear(in_features=2048, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=64),
        )
        self.dec = nn.Sequential(
            nn.Linear(in_features=64, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=num_classes),
        )
        self.pool = pool

    def forward(self, x):
        x = self.enc(x)
        if self.pool == "max":
            x = x.max(dim=1)[0]
        elif self.pool == "mean":
            x = x.mean(dim=1)
        elif self.pool == "sum":
            x = x.sum(dim=1)
        x = self.dec(x)
        return x


class SmallSetTransformer(nn.Module):
    def __init__(self,):
        super().__init__()
        self.enc = nn.Sequential(
            SAB(dim_in=2048, dim_out=64, num_heads=4),
            SAB(dim_in=64, dim_out=64, num_heads=4),
        )
        self.dec = nn.Sequential(
            PMA(dim=64, num_heads=4, num_seeds=1),
            nn.Linear(in_features=64, out_features=num_classes),
        )

    def forward(self, x):
        x = self.enc(x)
        x = self.dec(x)
        return x.squeeze(-1)



# Dataset class definition
class SetDataset(Dataset):
    def __init__(self, data_folder, annotations_file, mode='train'):
        self.data_folder = data_folder
        self.annotations = pd.read_csv(annotations_file)
        self.mode = mode
        self.annotations['path'] = self.annotations['filename'] + '_' + self.annotations['x_y'] + '.pt'
        self.annotations.set_index('path', inplace=True)
        self.data_files = [f for f in os.listdir(data_folder) if f.endswith('.pt')]
        if mode == 'train':
            self.data_files = [f for f in self.data_files if not any(x in f for x in ['case3', 'case4', 'control3', 'control4'])]
        else:
            self.data_files = [f for f in self.data_files if any(x in f for x in ['case3', 'case4', 'control3', 'control4'])]

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, index):
        file_path = self.data_files[index]
        data = torch.load(os.path.join(self.data_folder, file_path))
        label = self.annotations.loc[file_path, 'level']
        label = torch.tensor(label, dtype=torch.long)
        return data, label

# # Collate function to handle variable-sized data
# def collate_fn(batch):
#     xs, ys = zip(*batch)
#     return list(xs), torch.tensor(ys, dtype=torch.long)
def collate_fn(batch):
    xs, ys = zip(*batch)
    # Pad the sequences so they all have the same length
    xs = pad_sequence(xs, batch_first=True)  # Pads with zero by default
    ys = torch.tensor(ys, dtype=torch.long)
    return xs, ys



def train(model, data_loader, epochs, save_path):
    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss().to(device)  # Using CrossEntropyLoss as it is common for classification tasks
    losses = []

    # Set model to training mode
    model.train()
    
    for epoch in range(epochs):
        for data, label in data_loader:
            data = data.to(device)
            label = label.to(device)
            output = model(data)
            output = output.squeeze(1)  # This changes shape from [20, 1, 4] to [20, 4]
#             print(output)
#             print(label)
#             print("-"*60)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        print(f'Epoch {epoch+1}, Loss: {loss.item()}')  # Print loss for the epoch

    torch.save(model.state_dict(), save_path)
    return losses

def evaluate(model, data_loader, load_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load the saved weights
    model.load_state_dict(torch.load(load_path))

    model.eval()
    criterion = nn.CrossEntropyLoss().to(device)
    total_loss = 0
    total_correct = 0
    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            label = label.to(device)
            output = model(data)
            output = output.squeeze(1)
            loss = criterion(output, label)
            total_loss += loss.item()
            predictions = output.argmax(dim=1, keepdim=True)
            total_correct += predictions.eq(label.view_as(predictions)).sum().item()

    avg_loss = total_loss / len(data_loader)
    accuracy = total_correct / len(data_loader.dataset)
    print(f'Test Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')
    return avg_loss, accuracy


data_folder = '/projectnb/ec500kb/projects/project7/GTEx/annotated_patches/resnet_features' # the path for the stored features
annotations_file = '/projectnb/ec500kb/projects/project7/GTEx/annotated_patches/annotations.csv'
train_dataset = SetDataset(data_folder, annotations_file, mode='train')
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True, collate_fn=collate_fn)

test_dataset = SetDataset(data_folder, annotations_file, mode='test')
test_loader = DataLoader(test_dataset, batch_size=20, shuffle=True, collate_fn=collate_fn)

In [None]:
# Assuming you have a Set Transformer model defined as 'model'
model = SmallSetTransformer()

train_path = '/projectnb/ec500kb/projects/project7/set_transformer/model_weights.pth'
evaluate_path = '/projectnb/ec500kb/projects/project7/set_transformer/model_weights.pth'

train(model, train_loader, epochs=100, save_path=train_path)
evaluate(model, test_loader, load_path=evaluate_path)

In [35]:
# Assuming you have a Set Transformer model defined as 'model'
model = SmallDeepSet()

train_path = '/projectnb/ec500kb/projects/project7/set_transformer/model_weights_deep_set.pth'
evaluate_path = '/projectnb/ec500kb/projects/project7/set_transformer/model_weights_deep_set.pth'

train(model, train_loader, epochs=100, save_path=train_path)
evaluate(model, test_loader, load_path=evaluate_path)

Epoch 1, Loss: 1.3716249465942383
Epoch 2, Loss: 1.2860463857650757
Epoch 3, Loss: 1.4683560132980347
Epoch 4, Loss: 1.3566828966140747
Epoch 5, Loss: 1.1399704217910767
Epoch 6, Loss: 1.1359277963638306
Epoch 7, Loss: 1.4350744485855103
Epoch 8, Loss: 1.026046872138977
Epoch 9, Loss: 1.3894871473312378
Epoch 10, Loss: 0.9681815505027771
Epoch 11, Loss: 1.458438515663147
Epoch 12, Loss: 1.2430943250656128
Epoch 13, Loss: 1.1677699089050293
Epoch 14, Loss: 0.9662911295890808
Epoch 15, Loss: 0.8242725729942322
Epoch 16, Loss: 1.42889404296875
Epoch 17, Loss: 1.0772634744644165
Epoch 18, Loss: 0.9825153946876526
Epoch 19, Loss: 1.156739354133606
Epoch 20, Loss: 1.42207670211792
Epoch 21, Loss: 1.1884562969207764
Epoch 22, Loss: 1.0731812715530396
Epoch 23, Loss: 1.3973110914230347
Epoch 24, Loss: 1.1313308477401733
Epoch 25, Loss: 1.0063930749893188
Epoch 26, Loss: 1.1348589658737183
Epoch 27, Loss: 1.4415658712387085
Epoch 28, Loss: 1.5732933282852173
Epoch 29, Loss: 0.9724490642547607
E

(1.5559093952178955, 0.35)