In [28]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [30]:
from torch_tomogram_dataset.utils import calculate_edge_mean_2d
from torch_tomogram_dataset.transforms import augmentation_transform_2d

class Qata(Dataset):
    def __init__(self, root_dir, transform=None, augmentation=False):
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        self.augmentation = augmentation
        self.samples = []
        self.labels = []
        
        states = ['good', 'bad']
        for label, state in enumerate(states):
            path = os.path.join(self.root_dir, state)
            for file in os.listdir(path):
                data = Image.open(os.path.join(path, file)).convert('L')
                data = torch.from_numpy(np.array(data))
                
                self.samples.append(data)
                self.labels.append(label)
                
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        image_tensor = self.samples[idx]
        image_tensor = image_tensor.float().unsqueeze(0).repeat(3, 1, 1)
        label = self.labels[idx]
        
        if self.transform is not None:
            image_tensor = self.transform(image_tensor)
        
        # Rotation with edge mean filling
        if self.augmentation:
            edge_mean = calculate_edge_mean_2d(image_tensor)
            image_tensor = augmentation_transform_2d(image_tensor)
            image_tensor[image_tensor==0] = edge_mean

        return image_tensor, label

In [31]:
from torch_tomogram_dataset import AugmentedDatasetWrapper
from torchvision import models

train_dir = r"C:\rkka_Projects\cell_death_v1\Data\qpi_output\quality\train"
val_dir = r"C:\rkka_Projects\cell_death_v1\Data\qpi_output\quality\val"

transform = models.ResNet50_Weights.IMAGENET1K_V2.transforms()

train_dataset = Qata(root_dir=train_dir, transform=transform, augmentation=True)
augmented_train_dataset = AugmentedDatasetWrapper(dataset=train_dataset, num_repeats=3)
val_dataset = Qata(root_dir=val_dir, transform=transform, augmentation=False)

train_loader = DataLoader(augmented_train_dataset, batch_size=32)
val_loader = DataLoader(val_dataset, batch_size=32)


In [38]:
model = models.resnet50(pretrained=True)
num_features = model.fc.in_features
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(0.2),
    torch.nn.Linear(num_features, 2)
)

for name, param in model.named_parameters():
    if 'layer4.1' in name or 'fc' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False
    
# Count total and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")

Total Parameters: 23,512,130
Trainable Parameters: 4,466,690




In [39]:
model = model.cuda()

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [40]:
from tqdm import tqdm

# Train
num_epoch = 20

for epoch in tqdm(range(num_epoch)):
    # train
    model.train()
    train_loss, train_correct, train_total = 0, 0, 0
    
    for images, labels in train_loader:
        images, labels = images.cuda(), labels.cuda()
        
        # Calculate
        outputs = model(images)
        loss = criterion(outputs, labels)
        train_loss += loss.item()
        
        # Backpropagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        _, preds = torch.max(outputs, 1)
        train_correct += (preds==labels).sum().item()
        train_total += len(labels)
        
    # validation
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.cuda(), labels.cuda()
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            _, preds = torch.max(outputs, 1)
            val_correct += (preds==labels).sum().item()
            val_total += len(labels)
        
    print(f"Epoch : {epoch}")
    print(f"train loss : {train_loss/train_total:.6f} || train accuracy : {train_correct/train_total:.4f}")
    print(f"val loss : {val_loss/val_total:.6f} || val accuracy : {val_correct/val_total:.4f}")
    
    torch.save(model.state_dict(), f"epoch_{epoch}_valacc_{val_correct/val_total:.4f}.pth")

  5%|▌         | 1/20 [00:02<00:44,  2.35s/it]

Epoch : 0
train loss : 0.023813 || train accuracy : 0.5084
val loss : 0.020813 || val accuracy : 0.8922


 10%|█         | 2/20 [00:04<00:40,  2.25s/it]

Epoch : 1
train loss : 0.021206 || train accuracy : 0.5868
val loss : 0.018020 || val accuracy : 0.8922


 15%|█▌        | 3/20 [00:06<00:37,  2.23s/it]

Epoch : 2
train loss : 0.021485 || train accuracy : 0.5840
val loss : 0.016906 || val accuracy : 0.9020


 20%|██        | 4/20 [00:08<00:35,  2.23s/it]

Epoch : 3
train loss : 0.020308 || train accuracy : 0.6443
val loss : 0.015475 || val accuracy : 0.9118


 25%|██▌       | 5/20 [00:11<00:33,  2.23s/it]

Epoch : 4
train loss : 0.020010 || train accuracy : 0.6583
val loss : 0.013955 || val accuracy : 0.9118


 30%|███       | 6/20 [00:13<00:31,  2.22s/it]

Epoch : 5
train loss : 0.019354 || train accuracy : 0.6891
val loss : 0.014619 || val accuracy : 0.9118


 35%|███▌      | 7/20 [00:15<00:28,  2.17s/it]

Epoch : 6
train loss : 0.017832 || train accuracy : 0.7787
val loss : 0.014261 || val accuracy : 0.9118


 40%|████      | 8/20 [00:17<00:25,  2.14s/it]

Epoch : 7
train loss : 0.017122 || train accuracy : 0.7899
val loss : 0.013516 || val accuracy : 0.9314


 45%|████▌     | 9/20 [00:19<00:23,  2.18s/it]

Epoch : 8
train loss : 0.015693 || train accuracy : 0.8557
val loss : 0.012458 || val accuracy : 0.9412


 50%|█████     | 10/20 [00:22<00:21,  2.19s/it]

Epoch : 9
train loss : 0.014403 || train accuracy : 0.8936
val loss : 0.012457 || val accuracy : 0.9118


 55%|█████▌    | 11/20 [00:24<00:19,  2.14s/it]

Epoch : 10
train loss : 0.012981 || train accuracy : 0.9356
val loss : 0.012440 || val accuracy : 0.9216


 60%|██████    | 12/20 [00:26<00:16,  2.10s/it]

Epoch : 11
train loss : 0.011482 || train accuracy : 0.9524
val loss : 0.012202 || val accuracy : 0.9412


 65%|██████▌   | 13/20 [00:28<00:14,  2.07s/it]

Epoch : 12
train loss : 0.010357 || train accuracy : 0.9720
val loss : 0.012073 || val accuracy : 0.9314


 70%|███████   | 14/20 [00:30<00:12,  2.06s/it]

Epoch : 13
train loss : 0.009877 || train accuracy : 0.9692
val loss : 0.012439 || val accuracy : 0.9314


 75%|███████▌  | 15/20 [00:32<00:10,  2.11s/it]

Epoch : 14
train loss : 0.008875 || train accuracy : 0.9832
val loss : 0.011889 || val accuracy : 0.9118


 80%|████████  | 16/20 [00:34<00:08,  2.15s/it]

Epoch : 15
train loss : 0.008351 || train accuracy : 0.9860
val loss : 0.011125 || val accuracy : 0.9412


 85%|████████▌ | 17/20 [00:36<00:06,  2.18s/it]

Epoch : 16
train loss : 0.007917 || train accuracy : 0.9846
val loss : 0.011570 || val accuracy : 0.9216


 90%|█████████ | 18/20 [00:38<00:04,  2.14s/it]

Epoch : 17
train loss : 0.007261 || train accuracy : 0.9902
val loss : 0.011400 || val accuracy : 0.9510


 95%|█████████▌| 19/20 [00:40<00:02,  2.11s/it]

Epoch : 18
train loss : 0.006632 || train accuracy : 0.9972
val loss : 0.010959 || val accuracy : 0.9216


100%|██████████| 20/20 [00:43<00:00,  2.16s/it]

Epoch : 19
train loss : 0.006403 || train accuracy : 0.9902
val loss : 0.011226 || val accuracy : 0.9510





In [27]:
torch.save(model.state_dict(), 'sparsity_check.pth')