In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# List every file under the read-only Kaggle input mount, one absolute path per line.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for full_path in (os.path.join(root_dir, name) for name in file_names):
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install rasterio tqdm

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.models.segmentation import deeplabv3_resnet50
from tqdm import tqdm  # Import tqdm for the progress bar

# Custom Dataset Class for 8-Channel Segmentation
class CustomSegmentationDataset(Dataset):
    """Dataset of 8-channel image tensors (built by `Preprocessing`) and, for training, binary masks.

    Args:
        image_paths: Iterable of raster paths readable by `Preprocessing`.
        mask_paths: Iterable of mask paths aligned element-wise with `image_paths`.
            Required unless `is_test` is True.
        transform: Optional callable used only in training mode. It is applied to the image
            and mask *jointly*: it receives the (C, H, W) image with the mask appended as one
            extra channel. It must therefore be purely geometric (flips, crops, ...);
            intensity transforms would corrupt the mask channel.
        is_test: If True, masks are ignored and items are ``(image, {"image_id": tensor([idx])})``.

    Samples whose image (or, in training mode, mask) file does not exist are skipped when the
    dataset is built, and the number skipped is printed.
    """

    def __init__(self, image_paths, mask_paths=None, transform=None, is_test=False):
        if not is_test and mask_paths is None:
            raise ValueError("mask_paths is required unless is_test=True")
        self.image_paths = list(image_paths)
        self.mask_paths = None if is_test else list(mask_paths)
        self.transform = transform
        self.is_test = is_test

        if is_test:
            n_candidates = len(self.image_paths)
            self.samples = [(img_path, None) for img_path in self.image_paths if os.path.exists(img_path)]
        else:
            if len(self.image_paths) != len(self.mask_paths):
                # zip() below truncates to the shorter list; make that visible instead of silent.
                print(f"{type(self).__name__}: {len(self.image_paths)} images vs "
                      f"{len(self.mask_paths)} masks; extra entries are ignored")
            pairs = list(zip(self.image_paths, self.mask_paths))
            n_candidates = len(pairs)
            self.samples = [
                (img_path, mask_path)
                for img_path, mask_path in pairs
                if os.path.exists(img_path) and os.path.exists(mask_path)
            ]

        skipped = n_candidates - len(self.samples)
        if skipped:
            print(f"{type(self).__name__}: skipped {skipped} sample(s) whose files do not exist")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for training or ``(image, {"image_id": ...})`` in test mode.

        `image` is a float32 tensor of shape (C, H, W); `mask` is an int64 tensor of shape (H, W)
        with values in {0, 1}.
        """
        image_path, mask_path = self.samples[idx]
        image = Preprocessing(image_path).preprocess_image()  # (H, W, C) numpy array
        image = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)  # (C, H, W)

        if self.is_test:
            return image, {"image_id": torch.tensor([idx])}

        mask = self._load_mask(mask_path)
        if self.transform:
            image, mask = self._apply_joint_transform(image, mask)
        return image, mask  # Image (C, H, W), Mask (H, W)

    @staticmethod
    def _load_mask(mask_path):
        """Read a mask file and binarize it to an int64 (H, W) tensor (0 = background, 1 = class)."""
        mask = np.asarray(plt.imread(mask_path))
        if mask.ndim == 3:
            mask = mask[..., 0]  # keep only the first channel of multi-channel masks
        # Any positive value is foreground; 0 and the -9999 nodata value become background.
        # Thresholding (rather than casting floats to int) also keeps foreground when
        # plt.imread scales 8-bit PNG masks into [0, 1] (a stored 1 becomes 1/255).
        return torch.from_numpy((mask > 0).astype(np.int64))

    def _apply_joint_transform(self, image, mask):
        """Apply `self.transform` to image and mask with the same random draw.

        The mask is stacked as an extra channel so random flips/crops hit both identically.
        """
        stacked = torch.cat([image, mask.unsqueeze(0).to(image.dtype)], dim=0)
        stacked = self.transform(stacked)
        return stacked[:-1], stacked[-1].round().long()

# Preprocessing Class for 8 Channels
class Preprocessing:
    """Load a 6-band raster and build an 8-channel (H, W, 8) float32 image.

    Channel order: min-max normalized blue, green, red, nir, swir1, swir2, then NDVI and EVI.
    """

    def __init__(self, image_path):
        self.image_path = image_path

    def load_bands(self):
        """Read raster bands 1-6 as float32, returned as (blue, green, red, nir, swir1, swir2).

        The float cast matters: with unsigned integer rasters, `nir - red` would otherwise
        wrap around instead of going negative.
        """
        with rasterio.open(self.image_path) as src:
            return tuple(src.read(band_index).astype(np.float32) for band_index in range(1, 7))

    def preprocess_image(self):
        """Return the stacked (H, W, 8) float32 image with every non-finite value replaced by 0."""
        blue, green, red, nir, swir1, swir2 = self.load_bands()
        ndvi = self.compute_ndvi(red, nir)
        evi = self.compute_evi(nir, red, blue)
        normalized_bands = [self.normalize_band(band) for band in (blue, green, red, nir, swir1, swir2)]
        image = np.stack(normalized_bands + [ndvi, evi], axis=-1).astype(np.float32)  # Stack 8 channels
        # Any residual NaN/inf (e.g. NaN nodata pixels) would poison the loss; neutralize it.
        return np.nan_to_num(image, nan=0.0, posinf=0.0, neginf=0.0)

    def normalize_band(self, band):
        """Min-max scale `band` to [0, 1]; a constant (or all-NaN) band maps to zeros."""
        band = np.asarray(band, dtype=np.float32)
        with np.errstate(invalid="ignore"):
            low, high = np.nanmin(band), np.nanmax(band)
        span = high - low
        if not np.isfinite(span) or span == 0:
            # Original code divided by zero here and produced an all-NaN channel.
            return np.zeros_like(band)
        return (band - low) / span

    def compute_ndvi(self, red, nir):
        """NDVI = (nir - red) / (nir + red), computed in float32 with a small epsilon."""
        red = np.asarray(red, dtype=np.float32)
        nir = np.asarray(nir, dtype=np.float32)
        return (nir - red) / (nir + red + 1e-6)

    def compute_evi(self, nir, red, blue, g=2.5, c1=6, c2=7.5, l=1):
        """EVI = g * (nir - red) / (nir + c1*red - c2*blue + l), clipped to [0, 1].

        Pixels with a zero denominator are set to 0 instead of producing inf/NaN.
        NOTE(review): the standard constants assume surface reflectance in [0, 1]; the bands
        passed here are raw raster values. Confirm the raster's scaling.
        """
        nir, red, blue = (np.asarray(band, dtype=np.float32) for band in (nir, red, blue))
        numerator = g * (nir - red)
        denominator = nir + c1 * red - c2 * blue + l
        evi = np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator != 0)
        return np.clip(evi, 0, 1)

# Training and Model Setup
class DeepLabV3Model:
    """DeepLabV3-ResNet50 adapted to 8-channel input, with train / evaluate / checkpoint helpers."""

    def __init__(self, num_classes=2, device='cuda', lr=0.001, checkpoint_dir=None):
        """Build the network, loss and optimizer.

        Args:
            num_classes: Number of classes emitted by the segmentation head.
            device: torch device (or device string) the model is moved to.
            lr: Adam learning rate (defaults to the previously hard-coded 0.001).
            checkpoint_dir: Where `save_checkpoint` writes. ``None`` means ``/kaggle/working``
                when that directory exists (``/kaggle/input`` is read-only), otherwise the
                current working directory.
        """
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.model = deeplabv3_resnet50(pretrained=True)

        # Swap the stem so the backbone accepts 8 channels. NOTE(review): this conv is freshly
        # (randomly) initialized, so the pretrained RGB stem weights are not reused.
        self.model.backbone.conv1 = nn.Conv2d(8, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Replace the final 1x1 conv of the main head to emit `num_classes` logits.
        self.model.classifier[4] = nn.Conv2d(256, num_classes, kernel_size=1)

        # Only the 'out' head is used by train()/evaluate(). Any pretrained auxiliary head left
        # in place would still be run every forward pass for nothing, so disable it.
        self.model.aux_classifier = None

        self.model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def train(self, dataloader, num_epochs=10, checkpoint_interval=5):
        """Train for `num_epochs`, saving a checkpoint every `checkpoint_interval` epochs.

        `checkpoint_interval` <= 0 / None disables checkpointing.

        Returns:
            List with the mean training loss of each epoch (NaN for an empty dataloader).
        """
        epoch_losses = []
        for epoch in range(num_epochs):
            self.model.train()  # re-enter train mode every epoch in case evaluate() ran in between
            running_loss = 0.0
            num_batches = 0
            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

            for images, masks in pbar:
                images, masks = images.to(self.device), masks.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(images)['out']
                loss = self.criterion(outputs, masks)
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()
                num_batches += 1
                # Count batches ourselves: `pbar.n` is only updated when tqdm redraws, so it can
                # lag the true batch count and skew the running average.
                pbar.set_postfix({"Loss": running_loss / num_batches})

            epoch_loss = running_loss / num_batches if num_batches else float('nan')
            epoch_losses.append(epoch_loss)
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

            if checkpoint_interval and checkpoint_interval > 0 and (epoch + 1) % checkpoint_interval == 0:
                self.save_checkpoint(epoch + 1)
        return epoch_losses

    def evaluate(self, dataloader, positive_class=1):
        """Compute the dataset-level IoU of `positive_class` over a *labelled* dataloader.

        Intersection and union are accumulated over all pixels of all batches before dividing.
        Averaging per-batch ratios weights batches unequally and yields NaN when a batch has
        an empty union. The model's previous train/eval mode is restored afterwards.

        Returns:
            The IoU as a float, or NaN if neither predictions nor labels contain the class.

        Raises:
            TypeError: if the dataloader yields non-tensor targets (e.g. the unlabelled test set).
        """
        was_training = self.model.training
        self.model.eval()
        intersection = 0
        union = 0
        try:
            with torch.no_grad():
                for images, masks in dataloader:
                    if not torch.is_tensor(masks):
                        raise TypeError("evaluate() needs a dataloader with mask tensors; "
                                        "the test dataset has no labels")
                    images, masks = images.to(self.device), masks.to(self.device)
                    preds = torch.argmax(self.model(images)['out'], dim=1)
                    pred_fg = preds == positive_class
                    true_fg = masks == positive_class
                    intersection += (pred_fg & true_fg).sum().item()
                    union += (pred_fg | true_fg).sum().item()
        finally:
            self.model.train(was_training)

        if union == 0:
            print("Foreground IoU: undefined (class absent from predictions and labels)")
            return float('nan')
        iou = intersection / union
        print(f"Foreground IoU: {iou:.4f}")
        return iou

    @staticmethod
    def _default_checkpoint_dir():
        """Kaggle's persistent output directory if present, else the current working directory."""
        kaggle_working = "/kaggle/working"
        return kaggle_working if os.path.isdir(kaggle_working) else os.getcwd()

    def save_checkpoint(self, epoch):
        """Save the model `state_dict` as ``checkpoint_epoch_{epoch}.pth`` and return its path."""
        checkpoint_dir = getattr(self, "checkpoint_dir", None) or self._default_checkpoint_dir()
        os.makedirs(checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")
        return checkpoint_path

# Configuration
CSV_DIR = "/kaggle/input/train-test"
DATA_ROOT = "/kaggle/input/geo-ai-hack"  # paths inside the CSVs are relative to this directory
SEED = 42
VAL_FRACTION = 0.1  # share of labelled samples held out for IoU evaluation
BATCH_SIZE = 8
NUM_WORKERS = 2
NUM_EPOCHS = 10
CHECKPOINT_INTERVAL = 5

torch.manual_seed(SEED)
np.random.seed(SEED)

# Load dataset. Resolve paths against DATA_ROOT instead of `%cd`-ing into it: the input
# mount is read-only, and a cwd there makes relative checkpoint writes fail.
# (os.path.join leaves absolute paths untouched.)
train_csv = pd.read_csv(os.path.join(CSV_DIR, "train_ds.csv"))
test_csv = pd.read_csv(os.path.join(CSV_DIR, "test_ds.csv"))

train_image_paths = [os.path.join(DATA_ROOT, p) for p in train_csv["Input"]]
train_mask_paths = [os.path.join(DATA_ROOT, p) for p in train_csv["Label"]]
test_image_paths = [os.path.join(DATA_ROOT, p) for p in test_csv["Input"]]

# Data Transformations (geometric only, training only)
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
])

# Two views of the same labelled files: augmented for training, un-augmented for validation.
# Both apply the same file-existence filter, so their indices line up.
train_full = CustomSegmentationDataset(image_paths=train_image_paths, mask_paths=train_mask_paths, transform=transform)
val_full = CustomSegmentationDataset(image_paths=train_image_paths, mask_paths=train_mask_paths, transform=None)

split_generator = torch.Generator().manual_seed(SEED)
shuffled_indices = torch.randperm(len(train_full), generator=split_generator).tolist()
n_val = int(len(shuffled_indices) * VAL_FRACTION)
val_indices, train_indices = shuffled_indices[:n_val], shuffled_indices[n_val:]

train_dataset = Subset(train_full, train_indices)
val_dataset = Subset(val_full, val_indices)
# The test split has no masks; no augmentation is applied at inference time.
test_dataset = CustomSegmentationDataset(image_paths=test_image_paths, is_test=True)

# drop_last avoids a size-1 final batch, which BatchNorm in the ASPP pooling branch rejects in train mode.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
                              num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Initialize and train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepLabV3Model(device=device)
model.train(train_dataloader, num_epochs=NUM_EPOCHS, checkpoint_interval=CHECKPOINT_INTERVAL)

# Evaluate on the held-out labelled split (the test set has no masks, so IoU cannot be computed on it).
model.evaluate(val_dataloader)

/kaggle/input/geo-ai-hack


Epoch 1/10:   0%|          | 0/1304 [00:00<?, ?it/s]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 1/1304 [00:00<20:35,  1.05it/s, Loss=0.675]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


  return np.clip(g * (nir - red) / (nir + c1 * red - c2 * blue + l), 0, 1)


torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 3/1304 [00:02<18:23,  1.18it/s, Loss=0.594]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 4/1304 [00:03<17:49,  1.22it/s, Loss=0.563]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 5/1304 [00:04<17:43,  1.22it/s, Loss=0.532]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   0%|          | 6/1304 [00:05<17:47,  1.22it/s, Loss=0.504]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 7/1304 [00:05<17:42,  1.22it/s, Loss=0.475]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 8/1304 [00:06<17:44,  1.22it/s, Loss=0.447]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 9/1304 [00:07<17:31,  1.23it/s, Loss=0.423]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 10/1304 [00:08<18:04,  1.19it/s, Loss=0.401]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 11/1304 [00:09<18:03,  1.19it/s, Loss=0.38] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 12/1304 [00:09<17:50,  1.21it/s, Loss=0.36]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 13/1304 [00:10<17:47,  1.21it/s, Loss=0.343]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 14/1304 [00:11<17:48,  1.21it/s, Loss=0.327]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 15/1304 [00:12<17:37,  1.22it/s, Loss=0.312]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|          | 16/1304 [00:13<17:43,  1.21it/s, Loss=0.298]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|▏         | 17/1304 [00:14<17:37,  1.22it/s, Loss=0.285]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|▏         | 18/1304 [00:14<17:40,  1.21it/s, Loss=0.273]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   1%|▏         | 19/1304 [00:15<17:38,  1.21it/s, Loss=0.262]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 20/1304 [00:16<17:37,  1.21it/s, Loss=0.251]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 21/1304 [00:17<17:39,  1.21it/s, Loss=0.242]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 22/1304 [00:18<17:49,  1.20it/s, Loss=0.233]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 23/1304 [00:19<17:37,  1.21it/s, Loss=0.224]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 24/1304 [00:19<17:28,  1.22it/s, Loss=0.217]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 25/1304 [00:20<17:32,  1.22it/s, Loss=0.21] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 26/1304 [00:21<17:41,  1.20it/s, Loss=0.203]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 27/1304 [00:22<17:43,  1.20it/s, Loss=0.197]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 28/1304 [00:23<17:36,  1.21it/s, Loss=0.191]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 29/1304 [00:24<17:52,  1.19it/s, Loss=0.185]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 30/1304 [00:24<17:43,  1.20it/s, Loss=0.181]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 31/1304 [00:25<17:43,  1.20it/s, Loss=0.176]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   2%|▏         | 32/1304 [00:26<17:43,  1.20it/s, Loss=0.171]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 33/1304 [00:27<17:46,  1.19it/s, Loss=0.167]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 34/1304 [00:28<17:36,  1.20it/s, Loss=0.162]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 35/1304 [00:29<17:25,  1.21it/s, Loss=0.159]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 36/1304 [00:29<17:44,  1.19it/s, Loss=0.155]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 37/1304 [00:30<17:41,  1.19it/s, Loss=0.151]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 38/1304 [00:31<17:33,  1.20it/s, Loss=0.148]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 39/1304 [00:32<17:40,  1.19it/s, Loss=0.144]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 40/1304 [00:33<17:41,  1.19it/s, Loss=0.141]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 41/1304 [00:34<17:49,  1.18it/s, Loss=0.138]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 42/1304 [00:34<17:36,  1.20it/s, Loss=0.135]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 43/1304 [00:35<17:29,  1.20it/s, Loss=0.132]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 44/1304 [00:36<17:31,  1.20it/s, Loss=0.13] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   3%|▎         | 45/1304 [00:37<17:26,  1.20it/s, Loss=0.127]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▎         | 46/1304 [00:38<17:15,  1.22it/s, Loss=0.125]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▎         | 47/1304 [00:39<17:08,  1.22it/s, Loss=0.122]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▎         | 48/1304 [00:39<17:04,  1.23it/s, Loss=0.12] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▍         | 49/1304 [00:40<17:08,  1.22it/s, Loss=0.118]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▍         | 50/1304 [00:41<17:05,  1.22it/s, Loss=0.116]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▍         | 51/1304 [00:42<17:17,  1.21it/s, Loss=0.114]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▍         | 52/1304 [00:43<17:15,  1.21it/s, Loss=0.112]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▍         | 53/1304 [00:43<17:13,  1.21it/s, Loss=0.11] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▍         | 54/1304 [00:44<17:08,  1.21it/s, Loss=0.108]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▍         | 55/1304 [00:45<17:17,  1.20it/s, Loss=0.107]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▍         | 56/1304 [00:46<17:03,  1.22it/s, Loss=0.105]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▍         | 57/1304 [00:47<16:51,  1.23it/s, Loss=0.104]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   4%|▍         | 58/1304 [00:48<16:49,  1.23it/s, Loss=0.102]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▍         | 59/1304 [00:48<16:49,  1.23it/s, Loss=0.101]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▍         | 60/1304 [00:49<16:53,  1.23it/s, Loss=0.0992]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▍         | 61/1304 [00:50<16:46,  1.24it/s, Loss=0.0978]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▍         | 62/1304 [00:51<16:43,  1.24it/s, Loss=0.0964]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▍         | 63/1304 [00:52<16:54,  1.22it/s, Loss=0.095] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▍         | 64/1304 [00:52<16:46,  1.23it/s, Loss=0.0937]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▍         | 65/1304 [00:53<16:36,  1.24it/s, Loss=0.0924]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▌         | 66/1304 [00:54<16:29,  1.25it/s, Loss=0.0912]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▌         | 67/1304 [00:55<16:23,  1.26it/s, Loss=0.09]  

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▌         | 68/1304 [00:56<16:29,  1.25it/s, Loss=0.0888]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▌         | 69/1304 [00:56<16:32,  1.24it/s, Loss=0.0876]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▌         | 70/1304 [00:57<16:24,  1.25it/s, Loss=0.0865]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   5%|▌         | 71/1304 [00:58<16:18,  1.26it/s, Loss=0.0855]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▌         | 72/1304 [00:59<16:13,  1.26it/s, Loss=0.0845]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▌         | 73/1304 [01:00<16:23,  1.25it/s, Loss=0.0835]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▌         | 74/1304 [01:00<16:15,  1.26it/s, Loss=0.0825]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▌         | 75/1304 [01:01<16:36,  1.23it/s, Loss=0.0815]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▌         | 76/1304 [01:02<16:29,  1.24it/s, Loss=0.0806]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▌         | 77/1304 [01:03<16:29,  1.24it/s, Loss=0.0797]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▌         | 78/1304 [01:04<16:28,  1.24it/s, Loss=0.0788]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▌         | 79/1304 [01:04<16:27,  1.24it/s, Loss=0.0779]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▌         | 80/1304 [01:05<16:23,  1.24it/s, Loss=0.0771]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▌         | 81/1304 [01:06<16:23,  1.24it/s, Loss=0.0763]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▋         | 82/1304 [01:07<16:22,  1.24it/s, Loss=0.0755]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▋         | 83/1304 [01:08<16:22,  1.24it/s, Loss=0.0747]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   6%|▋         | 84/1304 [01:08<16:17,  1.25it/s, Loss=0.0739]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 85/1304 [01:09<16:21,  1.24it/s, Loss=0.0732]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 86/1304 [01:10<16:28,  1.23it/s, Loss=0.0724]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 87/1304 [01:11<16:31,  1.23it/s, Loss=0.0717]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 88/1304 [01:12<16:28,  1.23it/s, Loss=0.071] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 89/1304 [01:12<16:27,  1.23it/s, Loss=0.0703]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 90/1304 [01:13<16:22,  1.24it/s, Loss=0.0696]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 91/1304 [01:14<16:17,  1.24it/s, Loss=0.0689]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 92/1304 [01:15<16:20,  1.24it/s, Loss=0.0684]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 93/1304 [01:16<16:49,  1.20it/s, Loss=0.0677]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 94/1304 [01:17<16:30,  1.22it/s, Loss=0.0671]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 95/1304 [01:17<16:20,  1.23it/s, Loss=0.0665]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 96/1304 [01:18<16:13,  1.24it/s, Loss=0.0659]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   7%|▋         | 97/1304 [01:19<16:13,  1.24it/s, Loss=0.0653]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 98/1304 [01:20<16:10,  1.24it/s, Loss=0.0647]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 99/1304 [01:21<16:11,  1.24it/s, Loss=0.0641]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 100/1304 [01:21<16:07,  1.24it/s, Loss=0.0635]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 101/1304 [01:22<16:02,  1.25it/s, Loss=0.063] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 102/1304 [01:23<16:05,  1.24it/s, Loss=0.0624]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 103/1304 [01:24<16:07,  1.24it/s, Loss=0.0619]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 104/1304 [01:25<16:15,  1.23it/s, Loss=0.0614]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 105/1304 [01:25<16:16,  1.23it/s, Loss=0.0609]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 106/1304 [01:26<16:09,  1.24it/s, Loss=0.0603]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 107/1304 [01:27<16:16,  1.23it/s, Loss=0.0599]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 108/1304 [01:28<16:17,  1.22it/s, Loss=0.0594]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 109/1304 [01:29<16:19,  1.22it/s, Loss=0.059] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   8%|▊         | 110/1304 [01:30<16:15,  1.22it/s, Loss=0.0585]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▊         | 111/1304 [01:30<16:09,  1.23it/s, Loss=0.058] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▊         | 112/1304 [01:31<16:08,  1.23it/s, Loss=0.0575]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▊         | 113/1304 [01:32<16:10,  1.23it/s, Loss=0.0571]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▊         | 114/1304 [01:33<16:17,  1.22it/s, Loss=0.0567]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▉         | 115/1304 [01:34<16:07,  1.23it/s, Loss=0.0563]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▉         | 116/1304 [01:34<16:03,  1.23it/s, Loss=0.0559]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▉         | 117/1304 [01:35<16:01,  1.23it/s, Loss=0.0554]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▉         | 118/1304 [01:36<16:02,  1.23it/s, Loss=0.0551]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▉         | 119/1304 [01:37<16:15,  1.21it/s, Loss=0.0547]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▉         | 120/1304 [01:38<16:08,  1.22it/s, Loss=0.0544]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▉         | 121/1304 [01:39<16:08,  1.22it/s, Loss=0.054] 

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▉         | 122/1304 [01:39<16:20,  1.21it/s, Loss=0.0536]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:   9%|▉         | 123/1304 [01:40<16:10,  1.22it/s, Loss=0.0533]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:  10%|▉         | 124/1304 [01:41<16:07,  1.22it/s, Loss=0.0529]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:  10%|▉         | 125/1304 [01:42<16:04,  1.22it/s, Loss=0.0526]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:  10%|▉         | 126/1304 [01:43<16:02,  1.22it/s, Loss=0.0522]

torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])
torch.Size([256, 256])


Epoch 1/10:  10%|▉         | 127/1304 [01:43<15:55,  1.23it/s, Loss=0.0519]

In [None]:
# Evaluate the model and save predictions to CSV
def evaluate_and_save(model, dataloader, output_file, score_threshold=0.5):
    """Run inference over ``dataloader`` and write one binary label per image to a CSV.

    Each image is labelled 1 when the highest confidence score the model returns
    for it exceeds ``score_threshold``. Otherwise, including when the model
    returns no scores at all, it is labelled 0.

    NOTE(review): this expects detection-style outputs, i.e. ``model.model(images)``
    returning one dict per image with a ``"scores"`` entry. It also expects
    targets that are dicts carrying an ``"image_id"``. The dataset/model set up
    earlier in this notebook are for segmentation (``deeplabv3_resnet50``), whose
    forward pass does not return per-image score dicts. Confirm that the model
    and dataloader passed here actually match this contract.

    Args:
        model: Wrapper object exposing ``.model`` (the callable network, which
            must support ``.eval()``) and ``.device`` (where images are moved).
        dataloader: Iterable of ``(images, targets)`` batches, where ``images``
            is an iterable of tensors and ``targets`` is an indexable sequence of
            dicts with an ``"image_id"`` key (a tensor or a plain value).
        output_file: Path of the CSV to write, with header ``id,prediction``.
        score_threshold: Value the top score must strictly exceed for the image
            to be labelled positive. Defaults to 0.5, the previously
            hard-coded threshold.
    """
    import csv  # not among the notebook's visible top-level imports

    model.model.eval()
    results = []

    with torch.no_grad():
        for images, targets in dataloader:
            images = [image.to(model.device) for image in images]
            predictions = model.model(images)

            for idx, prediction in enumerate(predictions):
                image_id = targets[idx]["image_id"]
                # Accept both tensor ids (e.g. tensor([7])) and plain ids.
                if torch.is_tensor(image_id):
                    image_id = image_id.item()

                # Use the maximum rather than the first score so the label does
                # not depend on the scores being sorted in descending order.
                scores = torch.as_tensor(prediction["scores"])
                is_positive = scores.numel() > 0 and scores.max().item() > score_threshold
                label = 1 if is_positive else 0  # class 1 = positive, class 0 = negative

                results.append([image_id, label])

    with open(output_file, mode="w", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        writer.writerow(["id", "prediction"])
        writer.writerows(results)

    print(f"Predictions saved to {output_file}")

# Run evaluation on the test split and write the predictions CSV.
evaluate_and_save(
    model=model,
    dataloader=test_dataloader,
    output_file=OUTPUT_FILE,
)
