In [1]:
import timm
import torch
import torch.nn as nn
import sys
import os
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import datasets, transforms
sys.path.append(os.path.abspath(".."))
from data.ImageDataset import ImageDataset
from timm import create_model
from torchvision.models import resnet101
from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights
from torch_dct import dct_2d

In [2]:
xception = timm.create_model('xception', pretrained=False)
xception.fc = nn.Sequential(
    nn.Linear(xception.fc.in_features, 512),
    nn.ReLU(),                             
    nn.Dropout(p=0.5),                     
    nn.Linear(512, 1),                     
    nn.Sigmoid()                          
)

class ViTBinaryClassifier(nn.Module):
    def __init__(self, model_name="vit_base_patch16_224", pretrained=False, num_classes=1):
        super(ViTBinaryClassifier, self).__init__()
        self.vit = timm.create_model(model_name, pretrained=pretrained, drop_rate=0.6, attn_drop_rate=0.5)
        in_features = self.vit.head.in_features
        self.vit.head = nn.Sequential(
            nn.Linear(in_features, num_classes),
            nn.Sigmoid()  # Sigmoid for binary classification
        )

    def forward(self, x):
        return self.vit(x)


vit = ViTBinaryClassifier()

swin = create_model('swin_base_patch4_window7_224', pretrained=False, num_classes=1)

resnet = resnet101(pretrained=True)  
resnet.fc = nn.Linear(resnet.fc.in_features, 1)


class FFTResNet(nn.Module):
    def __init__(self, num_classes=1):
        super(FFTResNet, self).__init__()
        # Load a pretrained ResNet model
        self.resnet = resnet101(pretrained=False)

        # Modify the first convolutional layer to accept DCT input if needed
        self.resnet.conv1 = nn.Conv2d(
            6, 64, kernel_size=7, stride=2, padding=3, bias=False
        )  # Ensure it matches DCT input (3 channels)

        # Modify the output layer to match the number of classes
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 128),  # Add an intermediate FC layer
            nn.ReLU(),
            nn.Linear(128, num_classes),  # Output layer
            nn.Sigmoid()  # For binary classification
        )

    def apply_fft_batch(self, x):
        assert len(x.shape) == 4, "Expected input tensor of shape (B, C, H, W)"
        real_parts = torch.stack([torch.real(torch.fft.fft2(x[:, c, :, :])) for c in range(x.shape[1])], dim=1)
        imag_parts = torch.stack([torch.imag(torch.fft.fft2(x[:, c, :, :])) for c in range(x.shape[1])], dim=1)
        # Concatenate real and imaginary parts along the channel dimension
        fft_images = torch.cat([real_parts, imag_parts], dim=1)  # (B, 6, H, W) if input has 3 channels
        return fft_images


    def forward(self, x):
        # Apply DCT transformation
        x = self.apply_fft_batch(x)
        # Pass the DCT-transformed images through ResNet
        return self.resnet(x)
    
fft = FFTResNet()

efficient = efficientnet_b4(weights=EfficientNet_B4_Weights.DEFAULT)
efficient.classifier[1] = torch.nn.Linear(efficient.classifier[1].in_features, 1)



class DCTResNet(nn.Module):
    def __init__(self, num_classes=1):
        super(DCTResNet, self).__init__()
        # Load a pretrained ResNet model
        self.resnet = resnet101(pretrained=False)

        # Modify the first convolutional layer to accept DCT input if needed
        self.resnet.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False
        )  # Ensure it matches DCT input (3 channels)

        # Modify the output layer to match the number of classes
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 128),  # Add an intermediate FC layer
            nn.ReLU(),
            nn.Linear(128, num_classes),  # Output layer
            nn.Sigmoid()  # For binary classification
        )

    def apply_dct_batch(self, x):
        """
        Applies DCT to a batch of images.
        x: Tensor of shape (B, C, H, W)
        """
        assert len(x.shape) == 4, "Expected input tensor of shape (B, C, H, W)"
        # Apply DCT to each channel of each image in the batch
        dct_images = torch.stack([dct_2d(x[:, c, :, :]) for c in range(x.shape[1])], dim=1)
        return dct_images

    def forward(self, x):
        # Apply DCT transformation
        x = self.apply_dct_batch(x)
        # Pass the DCT-transformed images through ResNet
        return self.resnet(x)
    
dct = DCTResNet()



  model = create_fn(
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dct.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/DCTcnn/dct_cnn_3.pth"))
dct.to(device)
dct.eval()

resnet.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/Resnet/Resnet_7.pth"))
resnet.to(device)
resnet.eval()

fft.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/FFTcnn/fft_cnn_3.pth"))
fft.to(device)
fft.eval()

swin.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/SwinTransformer/Swin_9.pth"))
swin.to(device)
swin.eval()

efficient.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/Efficientnet_b4/efficientnet_b4_10.pth"))
efficient.to(device)
efficient.eval()

xception.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/ExceptionNet/exception_net_9.pth"))
xception = xception.to(device)
xception.eval()

vit.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_6.pth"))
vit = vit.to(device)
vit.eval()



  dct.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/DCTcnn/dct_cnn_3.pth"))
  resnet.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/Resnet/Resnet_7.pth"))
  fft.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/FFTcnn/fft_cnn_3.pth"))
  swin.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/SwinTransformer/Swin_9.pth"))
  efficient.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/Efficientnet_b4/efficientnet_b4_10.pth"))
  xception.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/ExceptionNet/exception_net_9.pth"))
  vit.load_state_dict(torch.load("/home/ec2-user/CS230Project/code/models/saved-weights/ViT/ViT_6.pth"))


ViTBinaryClassifier(
  (vit): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (patch_drop): Identity()
    (norm_pre): Identity()
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (q_norm): Identity()
          (k_norm): Identity()
          (attn_drop): Dropout(p=0.5, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='n

In [4]:

class EnsembleModel(nn.Module):
    def __init__(self, models, transformations, num_classes=1, input_shape=(3, 500, 500), device="cpu"):
        super(EnsembleModel, self).__init__()
        assert len(models) == len(transformations), "Each model must have a corresponding transformation."

        self.device = torch.device(device)
        self.models = nn.ModuleList([model.to(self.device) for model in models])  # Move models to the device
        self.transformations = transformations
        self.feature_dims = []

        for model in self.models:
            for param in model.parameters():
                param.requires_grad = False

        # Dynamically compute feature dimensions
        dummy_input = torch.randn(1, *input_shape).to(self.device)  # Move dummy input to the device
        for model, transform in zip(self.models, self.transformations):
            with torch.no_grad():
                transformed_input = transform(dummy_input)
                features = model(transformed_input)

                # Flatten features if needed
                if len(features.shape) > 2:  # If output is 4D, apply global pooling
                    features = torch.flatten(features, start_dim=1)
            self.feature_dims.append(features.shape[1])

        total_features = sum(self.feature_dims)

        # Shared classification head
        self.head = nn.Sequential(
            nn.Linear(total_features, 256),  # Intermediate fully connected layer
            nn.ReLU(),
            nn.Dropout(0.9),
            nn.Linear(256, num_classes),
            nn.Sigmoid()
        ).to(self.device)  # Move classification head to the device

    def forward(self, x):
        x = x.to(self.device)

        # Apply transformations and extract flattened features from each model
        features = []
        for model, transform in zip(self.models, self.transformations):
            transformed_input = transform(x)
            output = model(transformed_input)

            # Flatten features if needed
            if len(output.shape) > 2:  # If output is 4D, apply global pooling
                output = torch.flatten(output, start_dim=1)
            features.append(output)

        combined_features = torch.cat(features, dim=1)  # Concatenate along feature dimension
        return self.head(combined_features)



In [5]:
resnet.fc = nn.Identity()
vit.head = nn.Identity()
efficient.classifier[1] = nn.Identity()
fft.resnet.fc = nn.Identity()
dct.resnet.fc = nn.Identity()
xception.fc = nn.Identity()
swin.head = nn.Identity()



models_list = [resnet, dct, fft, swin, efficient, xception, vit]

transformation1 = transforms.Compose([
    transforms.Resize((224, 224)),                         
    transforms.Normalize(                     
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    ),
])

transformation2 = transforms.Compose([
    transforms.Resize((299, 299)),                          
    transforms.Normalize(                     
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    ),
])

transformation3 = transforms.Compose([
    transforms.Resize((299, 299)),                  
    transforms.Normalize(                     
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    ),
])

transformation4 = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  
])

transformations = [transformation1, transformation1, transformation1, transformation1, transformation2, transformation3, transformation4]


In [6]:
model = EnsembleModel(models=models_list, transformations=transformations, device=device)


In [7]:
criterion = nn.BCELoss()
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = model.to(device)

cuda


In [8]:
transform = transforms.Compose([
    transforms.Resize((500, 500)),           
    transforms.ToTensor()
])

train_dataset = ImageDataset(
    annotations_path="/home/ec2-user/CS230Project/data/annotations/train.json",
    images_dir="/home/ec2-user/CS230Project/data/train",
    transform=transform,
)

val_dataset = ImageDataset(
    annotations_path="/home/ec2-user/CS230Project/data/annotations/val.json",
    images_dir="/home/ec2-user/CS230Project/data/val",
    transform=transform,
)

train_loader = DataLoader(train_dataset, batch_size=64, num_workers=7,shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, num_workers=7, shuffle=False)

In [None]:
num_epochs = 10
best_val_acc = float("-inf")
for epoch in range(num_epochs):

    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device).float()  


        outputs = model(images)
        outputs = outputs.view(-1)  
        labels = labels.view(-1)  
        loss = criterion(outputs, labels)

    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        predicted = (outputs > 0.5).float()  
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_accuracy = 100. * correct / total
    print(f"Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Accuracy: {train_accuracy:.2f}%")

    model.eval()
    val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device).float()

            outputs = model(images)
            outputs = outputs.view(-1)  
            labels = labels.view(-1)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_accuracy = 100. * correct / total
    print(f"Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {val_accuracy:.2f}%")

    scheduler.step()
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        checkpoint_path = f"/home/ec2-user/CS230Project/code/models/saved-weights/LinearEnsemble/linear_ensemble_{epoch+1}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model saved to {checkpoint_path}")



Training Epoch 1/10:   0%|          | 0/690 [00:00<?, ?it/s]

Training Epoch 1/10: 100%|██████████| 690/690 [12:08<00:00,  1.06s/it]


Epoch 1, Train Loss: 0.0540, Accuracy: 98.76%


Validation:  24%|██▍       | 28/115 [00:30<01:25,  1.02it/s]