In [1]:
from tqdm import tqdm
from loguru import logger
import argparse
import torch
import numpy as np
from tqdm import tqdm
from ruamel.yaml import YAML

yaml = YAML()
with open("config.yaml") as f:
    config = yaml.load(f)


In [2]:
from torch.utils.data import Dataset
import h5py

logger.add("debugging.log")

class rcc_3d_dataset(Dataset):
    def __init__(self, list_id, config):
        self.config = config
        self.sub_sample_freq = 5

        with h5py.File(config['path_data'], 'r', swmr=True) as f:
            start_indexes = []
            accumulated_count = 0
            new_list_id = []
            for id in list_id:
                try: 
                    shape_pre = f[id]['image']['precontrast_r']['Axial'].shape
                    shape_post = f[id]['image']['post_50sec']['Axial'].shape
                    shape_late = f[id]['image']['post_5min_r']['Axial'].shape
                    shape_mask = f[id]['segmentation']['post_50sec']['Axial'].shape
                except KeyError:
                    continue

                assert(shape_pre == shape_post == shape_late == shape_mask)
                start_indexes.append(accumulated_count)
                accumulated_count += (shape_pre[0])//self.sub_sample_freq
                new_list_id.append(id)

        self.list_id = new_list_id
        self.start_indexes = start_indexes ##
        self.tot_count = accumulated_count ##
        self.path_data = config['path_data']

    def __len__(self):
        return self.tot_count
    
    def __getitem__(self, idx):
        id_index = np.digitize(idx, self.start_indexes)-1

        # random slice
        slice_index = (idx - self.start_indexes[id_index]) * self.sub_sample_freq + np.random.randint(0, 5)
        # random left or right
        orientation = np.random.randint(0, 2) 

        image_x = h5py.File(self.path_data, 'r', swmr=True)[self.list_id[id_index]]['image']['precontrast_r']['Axial'][slice_index:slice_index+1, 167:423, orientation*256:(orientation+1)*256]
        image_y = h5py.File(self.path_data, 'r', swmr=True)[self.list_id[id_index]]['image']['post_50sec']['Axial'][slice_index:slice_index+1, 167:423, orientation*256:(orientation+1)*256]

        image_x = (image_x/2000).astype(np.float32)
        image_y = (image_y/2000).astype(np.float32)

        image_x = np.clip(image_x, 0, 1)
        image_y = np.clip(image_y, 0, 1)    

        # -1, 1로 넣어줘야하는 듯
        image_x = image_x * 2 -1
        image_y = image_y * 2 -1
        return torch.from_numpy(image_x), torch.from_numpy(image_y)

def struct_dataset(phase,):
    yaml = YAML()
    with open('/home/synergyai/jth/rcc-classification-research/experiments/config/session_240621.yaml') as f:
        config = yaml.load(f)

    split_dict = {}
    with open(config['path_split'],'r') as f_split:
        split = f_split.read().splitlines()
        for i, split_name in enumerate(['train', 'valid', 'test']):
            string = split[i]
            string = string.replace('[','').replace(']','').replace("'", "")
            id_list = string.split(', ')
            split_dict[split_name] = id_list
    
    if phase == 'train':
        list_id = split_dict['train']
    elif phase == 'val':
        list_id = split_dict['valid']

    return rcc_3d_dataset(list_id, config)
    

In [3]:
train_ds = struct_dataset('train')
val_ds = struct_dataset('val')
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=config['train']['batch_size'], shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=config['train']['batch_size'], shuffle=True, num_workers=4)

In [4]:
x0, y0 = train_ds.__getitem__(1)
#(b, c, h, w) 

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

#seed=1024, resume=False, image_size=256, num_channels=2, centered=True, 
# use_geometric=False, beta_min=0.1, beta_max=20.0, num_channels_dae=64, 
# n_mlp=3, ch_mult=[1, 1, 2, 2, 4, 4], num_res_blocks=2, attn_resolutions=(16,), 
# dropout=0.0, resamp_with_conv=True, conditional=True, fir=True, fir_kernel=[1, 3, 3, 1], 
# skip_rescale=True, resblock_type='biggan', progressive='none', progressive_input='residual', 
# progressive_combine='sum', embedding_type='positional', fourier_scale=16.0, not_use_tanh=False, 
# exp='exp_syndiff', input_path='/home/synergyai/jth/ct-translation-research/input', 
# output_path='/home/synergyai/jth/ct-translation-research/output', nz=100, num_timesteps=4, 
# z_emb_dim=256, t_emb_dim=256, batch_size=1, num_epoch=500, ngf=64, lr_g=0.00016, lr_d=0.0001, 
# beta1=0.5, beta2=0.9, no_lr_decay=False, use_ema=True, ema_decay=0.999, r1_gamma=1.0, 
# lazy_reg=10, save_content=True, save_content_every=10, save_ckpt_every=10, lambda_l1_loss=0.5, 
# num_proc_node=1, num_process_per_node=1, node_rank=0, local_rank=0, master_address='127.0.0.1', 
# contrast1='T1', contrast2='T2', port_num='6021', world_size=1


In [6]:
# https://github.com/changzy00/pytorch-attention/blob/master/vision_transformers/ViT.py
class Attention(nn.Module):
    def __init__(self, dim, num_heads=4, qkv_bias=False, attn_drop=0, proj_drop=0):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(attn_drop) if attn_drop else nn.Identity()
        self.proj_drop = nn.Dropout(proj_drop) if proj_drop else nn.Identity()
        self.norm = nn.GroupNorm(num_groups = min(dim//4, 32), num_channels = dim, eps = 1e-6)

    def forward(self, x):
        # batch, sequence, channel
        B, N, C = x.shape
        x = self.norm(x.transpose(1,2)).transpose(1,2)
        #x = self.norm(x)
        # batch, sequence, (q/k/v), head, head_dim
        # -> (q/k/v), batch, head, sequence, head_dim
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # q,k,v : batch, head, sequence, head_dim
        q, k, v = qkv.unbind(0)
        # batch, head, sequence_q, sequence_k -> batch, head, sequence_k, sequence_q / scale by sqrt(dim)
        attn = (q @ k.transpose(-1, -2)) * self.scale
        # batch, head, sequence_k, sequence_q / softmax along sequence_q
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # (batch, head, sequence_k, sequence_q) * (batch, head, sequence, head_dim)
        # -> (batch, head, sequence_k, head_dim) -> (batch, sequence_k, head, head_dim) -> (batch, sequence_k, dim)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

attn = Attention(dim = 16, num_heads=2)
x = torch.randn(10, 128, 16)
x = attn(x)
print(x.shape)

torch.Size([10, 128, 16])


In [1]:
import os
os.environ['TORCH_CUDA_ARCH_LIST'] = '8.6'
from upfirdn2d import upfirdn2d

In [None]:

def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
  """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.

     Padding is performed only once at the beginning, not between the
     operations.
     The fused op is considerably more efficient than performing the same
     calculation
     using standard TensorFlow ops. It supports gradients of arbitrary order.
     Args:
       x:            Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
         C]`.
       w:            Weight tensor of the shape `[filterH, filterW, inChannels,
         outChannels]`. Grouped convolution can be performed by `inChannels =
         x.shape[0] // numGroups`.
       k:            FIR filter of the shape `[firH, firW]` or `[firN]`
         (separable). The default is `[1] * factor`, which corresponds to
         nearest-neighbor upsampling.
       factor:       Integer upsampling factor (default: 2).
       gain:         Scaling factor for signal magnitude (default: 1.0).

     Returns:
       Tensor of the shape `[N, C, H * factor, W * factor]` or
       `[N, H * factor, W * factor, C]`, and same datatype as `x`.
  """

  assert isinstance(factor, int) and factor >= 1

  # Check weight shape.
  assert len(w.shape) == 4
  convH = w.shape[2]
  convW = w.shape[3]
  inC = w.shape[1]
  outC = w.shape[0]

  assert convW == convH

  # Setup filter kernel.
  if k is None:
    k = [1] * factor
  k = _setup_kernel(k) * (gain * (factor ** 2))
  p = (k.shape[0] - factor) - (convW - 1)

  stride = (factor, factor)

  # Determine data dimensions.
  stride = [1, 1, factor, factor]
  output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
  output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
                    output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
  assert output_padding[0] >= 0 and output_padding[1] >= 0
  num_groups = _shape(x, 1) // inC

  # Transpose weights.
  w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
  w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
  w = torch.reshape(w, (num_groups * inC, -1, convH, convW))

  x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)

  return upfirdn2d(x, torch.tensor(k, device=x.device),
                   pad=((p + 1) // 2 + factor - 1, p // 2 + 1))


In [None]:
def get_kernalized_filter(k=[1, 3, 3, 1], gain = 1, factor = 2):
    k = np.asarray(k, dtype=np.float32)
    if k.ndim == 1:
        k = np.outer(k, k)
    k /= np.sum(k)
    return k*(gain*(factor**2))

# custom upsample with fir filtration
class conv_upscale_with_fir_sampling(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size = 3, fir_params = (1,3,3,1), scale_factor = 2):
        super().__init__()
        self.kernalized_filter = get_kernalized_filter(fir_params, factor=scale_factor)
        padding_value = (self.kernalized_filter.shape[0] - scale_factor) // (kernel_size-1)
        stride = [1,1,scale_factor,scale_factor]

        self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel_size, kernel_size))
        

        self.conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        # (batch, channel, height, width)
        x.shape()



        self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
                                                 kernel=3, up=True,
                                                 resample_kernel=fir_kernel,
                                                 use_bias=True,
                                                 kernel_init=default_init())
        

    def __init__(self, dim, scale_factor=2, mode='bilinear'):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode)
        self.norm = nn.GroupNorm(num_groups = min(dim//4, 32), num_channels = dim, eps = 1e-6)


    def forward(self, x):
        x = self.upsample(x)
        x = self.norm(x.transpose(1,2)).transpose(1,2)
        return x

In [None]:
#attn block, upsample, downsample, pyramid downsample, resnetblockbiggan

In [30]:

# (B,C,H,W)
x = torch.rand(10,128,16,16,)
# (B,C,patch_H, patch_W)
attn(x)

TypeError: MultiheadAttention.forward() missing 2 required positional arguments: 'key' and 'value'

In [None]:
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)
    
class NCSN(nn.Module):
    def __init__(self):
        super(NCSN, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.residual_blocks = nn.Sequential(
            *[ResNetBlock(256, 256) for _ in range(6)]
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )



In [None]:





class ResNetGenerator(nn.Module):
    def __init__(self):
        super(ResNetGenerator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.residual_blocks = nn.Sequential(
            *[ResNetBlock(256, 256) for _ in range(6)]
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.residual_blocks(x)
        x = self.decoder(x)
        return x

# Example usage
input_image = torch.randn(1, 1, 256, 256)  # Batch size of 1, single channel 256x256 image
gen = ResNetGenerator()
output_image = gen(input_image)
print(output_image.shape)

In [None]:
@torch.nograd()
def train_epoch(model, train_loader, optimizer, criterion, device):
    model.train()
    train_loss = 0
    for x, y in tqdm(train_loader):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    return train_loss/len(train_loader)

In [None]:
def initialize_weights(shape, in_axis=1, out_axis=0, dtype=torch.float32, device='cpu', variance=1.0):
    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        return fan_in, fan_out

    # Compute fan_in and fan_out
    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)

    # Use fan_in for the denominator in the scaling factor
    denominator = fan_in

    # Calculate the scaling factor
    scaling_factor = np.sqrt(3 * variance / denominator)

    # Generate the weights
    weights = (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * scaling_factor

    return weights

In [None]:
# _compute_fans(shape, in_axis=1, out_axis=0):
# fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
# receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
# fan_in = shape[in_axis] * receptive_field_size
# fan_out = shape[out_axis] * receptive_field_size
# denominator = fan_in
# (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
#

In [21]:
for x, y in train_loader:
    print(x.shape, y.shape)
    break

torch.Size([4, 1, 256, 256]) torch.Size([4, 1, 256, 256])
