## Скачивание данных и импорт пакетов

In [None]:
# The Kaggle CLI reads its API token from ~/.kaggle/kaggle.json.
# -p keeps this cell re-runnable when the directory already exists.
!mkdir -p ~/.kaggle
!pip install kaggle
!cp kaggle.json ~/.kaggle/kaggle.json
# Make the token readable by the current user only; the Kaggle CLI warns otherwise.
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d paultimothymooney/chest-xray-pneumonia

Downloading chest-xray-pneumonia.zip to /content
100% 2.29G/2.29G [00:36<00:00, 84.2MB/s]
100% 2.29G/2.29G [00:36<00:00, 67.7MB/s]


In [None]:
# -n: never overwrite files that are already extracted, so re-running the cell
# does not stop at unzip's interactive "replace?" prompt.
!unzip -qq -n chest-xray-pneumonia.zip

In [None]:
# Pinned to the release this notebook was originally run against, so the metric
# class names imported below keep resolving. %pip installs into the kernel's
# own environment.
%pip install -q torchmetrics==0.7.2

Collecting torchmetrics
  Downloading torchmetrics-0.7.2-py3-none-any.whl (397 kB)
[?25l[K     |▉                               | 10 kB 21.8 MB/s eta 0:00:01[K     |█▋                              | 20 kB 21.3 MB/s eta 0:00:01[K     |██▌                             | 30 kB 17.5 MB/s eta 0:00:01[K     |███▎                            | 40 kB 14.1 MB/s eta 0:00:01[K     |████▏                           | 51 kB 13.0 MB/s eta 0:00:01[K     |█████                           | 61 kB 14.9 MB/s eta 0:00:01[K     |█████▊                          | 71 kB 13.0 MB/s eta 0:00:01[K     |██████▋                         | 81 kB 13.6 MB/s eta 0:00:01[K     |███████▍                        | 92 kB 14.8 MB/s eta 0:00:01[K     |████████▎                       | 102 kB 15.9 MB/s eta 0:00:01[K     |█████████                       | 112 kB 15.9 MB/s eta 0:00:01[K     |██████████                      | 122 kB 15.9 MB/s eta 0:00:01[K     |██████████▊                     | 133 kB 15.9 

In [None]:
import os

from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset


import math

import numpy as np

import torchvision
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize, Grayscale
from torchmetrics import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio

from PIL import Image

import matplotlib.pyplot as plt

%matplotlib inline

## Загрузка данных


In [None]:
# Side length, in pixels, of the square high-resolution training crops.
CROP_SIZE = 512
# Ratio between the high-resolution crop size and the low-resolution input size.
UPSCALE_FACTOR = 4

In [None]:
def is_image_file(filename):
    """Return True when ``filename`` ends with a recognised image extension.

    Only the all-lowercase and all-uppercase spellings listed below are
    accepted; mixed-case suffixes such as ``.Jpg`` are rejected.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG')
    # str.endswith accepts a tuple and succeeds if any suffix matches.
    return filename.endswith(image_suffixes)


def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round ``crop_size`` down to the nearest multiple of ``upscale_factor``."""
    whole_steps = crop_size // upscale_factor
    return whole_steps * upscale_factor


def train_hr_transform(crop_size):
    """Build the pipeline that produces a high-resolution training crop.

    Steps, in order: resize with ``Resize(1024)``, convert to grayscale,
    take a random ``crop_size`` crop, and convert the result to a tensor.
    """
    steps = [
        Resize(1024),
        Grayscale(),
        RandomCrop(crop_size),
        ToTensor(),
    ]
    return Compose(steps)


def train_lr_transform(crop_size, upscale_factor):
    """Build the pipeline that turns a high-resolution tensor into its low-resolution input.

    The tensor is converted back to a PIL image, made grayscale, shrunk with
    bicubic interpolation to ``crop_size // upscale_factor``, and converted to
    a tensor again.

    Args:
        crop_size: size of the high-resolution crop.
        upscale_factor: integer ratio between high and low resolution.
    """
    return Compose([
        ToPILImage(),
        Grayscale(),
        # Use torchvision's InterpolationMode enum: passing the PIL integer
        # constant (Image.BICUBIC) makes Resize emit a warning.
        Resize(crop_size // upscale_factor,
               interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
        ToTensor()
    ])


class TrainDatasetFromFolder(Dataset):
    """Dataset of (low-resolution, high-resolution) pairs built from an image folder.

    Every file under ``dataset_dir`` (searched recursively) that passes
    ``is_image_file`` becomes one sample. The high-resolution image comes from
    ``train_hr_transform``; the low-resolution one is derived from it with
    ``train_lr_transform``.

    Args:
        dataset_dir: root directory to scan for images.
        crop_size: requested high-resolution crop size; rounded down to a
            multiple of ``upscale_factor`` before use.
        upscale_factor: integer ratio between high and low resolution.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super().__init__()
        # os.walk yields entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.image_filenames = sorted(
            os.path.join(root, name)
            for root, _, names in os.walk(dataset_dir)
            for name in names
            if is_image_file(name)
        )
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        """Return the ``(lr_image, hr_image)`` pair for the file at ``index``."""
        # The context manager closes the underlying file once the transform
        # has produced its own output, instead of leaking the handle.
        with Image.open(self.image_filenames[index]) as image:
            hr_image = self.hr_transform(image)
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Number of image files found under ``dataset_dir``."""
        return len(self.image_filenames)

In [None]:
train_set = TrainDatasetFromFolder("chest_xray/train", crop_size=CROP_SIZE,
                                   upscale_factor=UPSCALE_FACTOR)
# Never request more worker processes than the machine has CPU cores:
# DataLoader warns when num_workers exceeds what the host can provide.
trainloader = DataLoader(train_set, batch_size=2,
                         num_workers=min(4, os.cpu_count() or 1), shuffle=True)

  "Argument interpolation should be of type InterpolationMode instead of int. "
  cpuset_checked))


## Модель

In [None]:
class RWMAB(nn.Module):
  """Residual block whose convolutional branch is gated by a sigmoid mask.

  For an input ``x`` the block returns ``x + part2(x) * part1(x)``, where
  ``part1`` is conv3x3 -> ReLU -> conv3x3 and ``part2`` is conv1x1 -> sigmoid.
  All convolutions keep the channel count and spatial size, so the output has
  the same shape as the input.

  Args:
    in_channels: number of channels of the input (and output) feature map.
  """

  def __init__(self, in_channels):
    # nn.Module must be initialised before any submodule is assigned;
    # without this call the assignments below raise AttributeError.
    super().__init__()

    self.part1 = nn.Sequential(
        nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(in_channels, in_channels, (3, 3), stride=1, padding=1)
    )
    self.part2 = nn.Sequential(
        nn.Conv2d(in_channels, in_channels, (1, 1), stride=1, padding=0),
        nn.Sigmoid()
    )

  def forward(self, x):
    """Apply the gated residual update to ``x``."""
    x1 = self.part1(x)  # feature branch
    x2 = self.part2(x)  # per-element gate in (0, 1)

    return x2*x1+x


class ShortResidualBlock(nn.Module):
  """A chain of RWMAB blocks wrapped in one skip connection.

  The input is passed through ``num_layers`` RWMAB blocks in sequence and the
  result is added back to the original input.

  Args:
    in_channels: number of channels of the feature map.
    num_layers: how many RWMAB blocks to chain (default 16).
  """

  def __init__(self, in_channels, num_layers=16):
    super().__init__()

    self.layers = nn.ModuleList([RWMAB(in_channels) for _ in range(num_layers)])

  def forward(self, x):
    """Return ``x`` plus the output of the RWMAB chain applied to ``x``."""
    # No defensive copy of x is needed: each layer returns a new tensor, so
    # rebinding ``out`` never modifies the input used by the skip connection.
    out = x
    for layer in self.layers:
      out = layer(out)

    return x + out
    

class Generator(nn.Module):
    """Super-resolution generator that upscales its input by a factor of 4.

    A 3x3 convolution lifts the input to 64 channels; the result passes
    through ``blocks`` ShortResidualBlock modules and a 1x1 convolution, is
    concatenated with the initial 64-channel features, and is upsampled by
    two conv + PixelShuffle(2) stages. A final 1x1 convolution and sigmoid
    produce a single-channel output with values in (0, 1).

    Args:
        in_channels: number of channels of the input image.
        blocks: number of ShortResidualBlock modules in the trunk.
    """

    def __init__(self, in_channels=1, blocks=8):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)

        trunk = [ShortResidualBlock(64) for _ in range(blocks)]
        self.short_blocks = nn.ModuleList(trunk)

        self.conv2 = nn.Conv2d(64, 64, kernel_size=1, stride=1, padding=0)

        # Each PixelShuffle(2) trades 4x channels for 2x spatial resolution,
        # so 256 channels become 64 after every shuffle.
        upsample_stages = [
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.conv3 = nn.Sequential(*upsample_stages)

    def forward(self, x):
        """Map an input batch to its 4x upscaled, single-channel output."""
        shallow = self.conv(x)
        deep = shallow.clone()

        for block in self.short_blocks:
            deep = block(deep)

        # Fuse trunk output with the shallow features along the channel axis.
        fused = torch.cat([self.conv2(deep), shallow], dim=1)
        return self.conv3(fused)

In [None]:
class D_Block(nn.Module):
    """Discriminator building block: 3x3 convolution, batch norm, LeakyReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: stride of the convolution (default 2).
    """

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()

        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                         stride=stride, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU()
        self.layer = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> LeakyReLU."""
        out = self.layer(x)
        return out


class Discriminator(nn.Module):
    """Two-input discriminator that scores a pair of images jointly.

    ``x1`` goes through a branch with two stride-2 blocks (4x spatial
    reduction), ``x2`` through a branch with stride-1 layers only. The two
    128-channel maps are concatenated, so ``x1`` must be 4x the spatial size
    of ``x2``. The shared trunk applies four more stride-2 blocks, and two
    linear layers followed by a sigmoid produce two values per sample.

    Args:
        img_size: ``(height, width)`` of the second input ``x2``; used to
            size the first linear layer.
        in_channels: number of channels of both inputs.
    """

    # Number of stride-2 blocks in the shared trunk (block4, 6, 7, 8).
    _TRUNK_STRIDE2_BLOCKS = 4

    @staticmethod
    def _size_after_stride2(size, num_blocks):
        """Spatial size after ``num_blocks`` 3x3/stride-2/padding-1 convolutions."""
        for _ in range(num_blocks):
            # kernel 3, padding 1, stride 2 -> ceil(size / 2)
            size = (size + 1) // 2
        return size

    def __init__(self, img_size, in_channels=1):
        super().__init__()

        self.conv_1_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1), nn.LeakyReLU()
        )

        self.block_1_1 = D_Block(64, 64, stride=2)
        self.block_1_2 = D_Block(64, 128, stride=1)
        self.block_1_3 = D_Block(128, 128)

        self.conv_2_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, (3, 3), stride=1, padding=1), nn.LeakyReLU()
        )

        self.block_2_2 = D_Block(64, 128, stride=1)

        self.block3 = D_Block(256, 256, stride=1)
        self.block4 = D_Block(256, 256)
        self.block5 = D_Block(256, 512, stride=1)
        self.block6 = D_Block(512, 512)
        self.block7 = D_Block(512, 1024)
        self.block8 = D_Block(1024, 1024)

        self.flatten = nn.Flatten()

        # Derive the flattened size from the real per-dimension output sizes.
        # The shortcut H * W // 256 is only right when both dimensions are
        # multiples of 16; for those sizes this gives the same number.
        out_h = self._size_after_stride2(img_size[0], self._TRUNK_STRIDE2_BLOCKS)
        out_w = self._size_after_stride2(img_size[1], self._TRUNK_STRIDE2_BLOCKS)
        self.fc1 = nn.Linear(1024 * out_h * out_w, 100)
        self.fc2 = nn.Linear(100, 2)

        self.relu = nn.LeakyReLU(negative_slope=0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2):
        """Return a ``(batch, 2)`` tensor of sigmoid outputs for the pair ``(x1, x2)``."""
        x_1 = self.block_1_3(self.block_1_2(self.block_1_1(self.conv_1_1(x1))))
        x_2 = self.block_2_2(self.conv_2_1(x2))

        # Both branches now have 128 channels at the resolution of x2.
        x = torch.cat([x_1, x_2], dim=1)
        x = self.block8(
            self.block7(self.block6(self.block5(self.block4(self.block3(x)))))
        )

        x = self.flatten(x)

        x = self.fc1(x)
        x = self.fc2(self.relu(x))

        return self.sigmoid(x)

## Обучение, ваше решение

Вам требуется реализовать функции потерь и цикл обучения, задать оптимизаторы и прочие гиперпараметры так же, как в SRGAN, а также посчитать метрики качества (SSIM и PSNR).