In [1]:
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import datasets, transforms, models
from torch_dct import dct_2d
sys.path.append(os.path.abspath(".."))
from data.ImageDataset import ImageDataset

In [None]:

class FFTResNet(nn.Module):
    """ResNet-101 classifier that operates on the 2-D Fourier spectrum of its input.

    Every input channel is transformed with a 2-D FFT over its two spatial
    dimensions. The real and imaginary parts are concatenated along the channel
    axis (doubling the channel count) and the result is fed to a torchvision
    ResNet-101 whose first convolution and final classifier have been replaced.

    Args:
        num_classes: Number of output units. Each output goes through a
            sigmoid, so the model returns values in (0, 1), not logits.
        in_channels: Number of channels of the images given to ``forward``.
            The first convolution is built for ``2 * in_channels`` channels
            (real + imaginary parts). Defaults to 3, i.e. a 6-channel stem.
    """

    def __init__(self, num_classes=1, in_channels=3):
        super().__init__()
        # Backbone loaded with torchvision's pretrained weights.
        self.resnet = models.resnet101(pretrained=True)

        # Replace the stem so it accepts the stacked real/imaginary channels.
        # NOTE: this is a freshly initialised layer; the pretrained conv1
        # weights are not reused.
        self.resnet.conv1 = nn.Conv2d(
            2 * in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )

        # Replace the classification head with a small MLP ending in a sigmoid.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_features, 128),  # intermediate FC layer
            nn.ReLU(),
            nn.Linear(128, num_classes),  # output layer
            nn.Sigmoid()  # squashes outputs to (0, 1) for binary classification
        )

    def apply_fft_batch(self, x):
        """Return the real and imaginary parts of the per-channel 2-D FFT of ``x``.

        Args:
            x: Tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, 2*C, H, W): the first C channels hold the real
            parts, the last C channels hold the imaginary parts.

        Raises:
            AssertionError: if ``x`` is not 4-dimensional.
        """
        assert len(x.shape) == 4, "Expected input tensor of shape (B, C, H, W)"
        # fft2 transforms the last two dimensions by default, so a single call
        # handles every sample and channel at once; there is no need to loop
        # over channels or to run the transform separately for real and imag.
        spectrum = torch.fft.fft2(x)
        return torch.cat([spectrum.real, spectrum.imag], dim=1)

    def forward(self, x):
        """Transform ``x`` to the frequency domain and classify it with the ResNet."""
        x = self.apply_fft_batch(x)
        return self.resnet(x)

In [3]:
# Pick the compute device first so the model can be placed on it right away.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Single-output classifier; the model's head already ends in a sigmoid, hence
# plain BCELoss on its outputs.
model = FFTResNet(num_classes=1).to(device)
criterion = nn.BCELoss()

# Optimise only the parameters that require gradients.
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = Adam(trainable_params, lr=1e-4)

# Multiply the learning rate by 0.1 every 7 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)



cuda


In [4]:
# Root of the dataset tree. Defaults to the original location; set the
# CS230_DATA_ROOT environment variable to run the notebook on another machine.
DATA_ROOT = os.environ.get("CS230_DATA_ROOT", "/home/ec2-user/CS230Project/data")
BATCH_SIZE = 64
NUM_WORKERS = 7

# Same preprocessing for both splits: fixed 224x224 size, tensor conversion,
# then per-channel normalisation with the mean/std values that torchvision
# documents for its ImageNet-pretrained models.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])


def _build_dataset(split):
    """Create the ImageDataset for ``split`` ("train" or "val").

    Annotations are read from ``<DATA_ROOT>/annotations/<split>.json`` and
    images from ``<DATA_ROOT>/<split>``.
    """
    return ImageDataset(
        annotations_path=os.path.join(DATA_ROOT, "annotations", f"{split}.json"),
        images_dir=os.path.join(DATA_ROOT, split),
        transform=transform,
    )


train_dataset = _build_dataset("train")
val_dataset = _build_dataset("val")

# Only the training data is shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

In [None]:
num_epochs = 10

# Directory that receives one checkpoint per epoch. It is created up front so
# a missing folder is handled before any training time is spent, instead of
# torch.save failing at the end of the first epoch.
CHECKPOINT_DIR = "/home/ec2-user/CS230Project/code/models/saved-weights/FFTcnn"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)


def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return ``(loss_sum, correct, total)``.

    When ``optimizer`` is given the model is put in train mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled.

    Returns:
        loss_sum: sum of the per-batch loss values.
        correct: number of samples whose thresholded output (> 0.5) equals
            the label.
        total: number of samples seen.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0
    correct = 0
    total = 0

    with torch.set_grad_enabled(is_training):
        for images, labels in tqdm(loader, desc=desc):
            images = images.to(device)
            # Flatten outputs and labels to 1-D so BCELoss sees matching shapes.
            labels = labels.to(device).float().view(-1)
            outputs = model(images).view(-1)
            loss = criterion(outputs, labels)

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return loss_sum, correct, total


for epoch in range(num_epochs):

    # Training pass
    train_loss, correct, total = _run_epoch(
        model, train_loader, criterion, device,
        desc=f"Training Epoch {epoch+1}/{num_epochs}",
        optimizer=optimizer,
    )
    train_accuracy = 100. * correct / total
    print(f"Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Accuracy: {train_accuracy:.2f}%")

    # Validation pass (no parameter updates)
    val_loss, correct, total = _run_epoch(
        model, val_loader, criterion, device, desc="Validation"
    )
    val_accuracy = 100. * correct / total
    print(f"Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {val_accuracy:.2f}%")

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    checkpoint_path = os.path.join(CHECKPOINT_DIR, f"fft_cnn_{epoch+1}.pth")
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model saved to {checkpoint_path}")



Training Epoch 1/10:   8%|▊         | 55/690 [00:22<04:15,  2.49it/s]


KeyboardInterrupt: 

: 