In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
import os
import cv2
import wandb
import time
import torch
import torch.nn as nn
import numpy as np
import torchvision
import matplotlib
import matplotlib.pyplot as plt
import albumentations as A
import torch.optim as optim
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from albumentations.pytorch import ToTensorV2
import torch.nn.functional as F
from tqdm.auto import tqdm

In [11]:
original_height=384
original_width =384
#original_height=224
#original_width =224

transformations = A.Compose([
    #A.Resize(224,224),
    A.OneOf([
        A.RandomSizedCrop(min_max_height=(50, 101), height=original_height, width=original_width, p=0.5),
        A.PadIfNeeded(min_height=original_height, min_width=original_width, p=0.5)
    ], p=1),    
    A.VerticalFlip(p=0.5),              
    A.RandomRotate90(p=0.5),
    A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),
        A.GridDistortion(p=0.5),
        A.OpticalDistortion(distort_limit=2, shift_limit=0.5, p=1)                  
        ], p=0.8),
    #A.CLAHE(p=0.8),
    #A.RandomBrightnessContrast(p=0.8),    
    A.RandomGamma(p=0.8),
    ])
class CloudDataset(Dataset):
    def __init__(self, r_dir, g_dir, b_dir, nir_dir, gt_dir, pytorch=True,transform=False):
        super().__init__()
        
        # Loop through the files in red folder and combine, into a dictionary, the other bands
        self.files = [self.combine_files(f, g_dir, b_dir, nir_dir, gt_dir) for f in r_dir.iterdir() if not f.is_dir()]
        self.pytorch = pytorch
        self.transform = transform
        
    def combine_files(self, r_file: Path, g_dir, b_dir,nir_dir, gt_dir):
        
        files = {'red': r_file, 
                 'green':g_dir/r_file.name.replace('red', 'green'),
                 'blue': b_dir/r_file.name.replace('red', 'blue'), 
                 'nir': nir_dir/r_file.name.replace('red', 'nir'),
                 'gt': gt_dir/r_file.name.replace('red', 'gt')}

        return files
                                       
    def __len__(self):
        
        return len(self.files)
     
    def open_as_array(self, idx, invert=False, include_nir=True):

        raw_rgb = np.stack([np.array(Image.open(self.files[idx]['red'])),
                            np.array(Image.open(self.files[idx]['green'])),
                            np.array(Image.open(self.files[idx]['blue'])),
                           ], axis=2)
    
        if include_nir:
            nir = np.expand_dims(np.array(Image.open(self.files[idx]['nir'])), 2)
            raw_rgb = np.concatenate([nir,raw_rgb], axis=2)
    
        if invert:
            raw_rgb = raw_rgb.transpose((2,0,1))
    
        # normalize
        return (raw_rgb / np.iinfo(raw_rgb.dtype).max)
    

    def open_mask(self, idx, add_dims=False):
        
        raw_mask = np.array(Image.open(self.files[idx]['gt']))
        raw_mask = np.where(raw_mask==255, 1, 0)
        
        return np.expand_dims(raw_mask, 0) if add_dims else raw_mask
    
    def __getitem__(self, idx):
                    
        x = self.open_as_array(idx, invert=False, include_nir=True)
        
        y = self.open_mask(idx, add_dims=False)
       
        if self.transform is not None:
            augmented = self.transform(image=x,mask=y)
            x=augmented['image']#,dtype=torch.float64)
            y=augmented['mask']#dtype=torch.torch.float32) 
            
        x=torch.from_numpy(x)#,dtype=torch.float64)
        y=torch.tensor(y,dtype=torch.float32) 
        x = x.permute(2,0,1)
        #y=torch.tensor(y,dtype=torch.float32)
        
        return x.double(), y.unsqueeze(0)
    
    def open_as_pil(self, idx):
        
        arr = 256*self.open_as_array(idx)
        
        return Image.fromarray(arr.astype(np.uint8), 'RGB')
    
    def __repr__(self):
        s = 'Dataset class with {} files'.format(self.__len__())

        return s

In [12]:
base_path = Path('../input/38cloud-cloud-segmentation-in-satellite-images/38-Cloud_training')
dataset = CloudDataset(base_path/'train_red', 
                    base_path/'train_green', 
                    base_path/'train_blue', 
                    base_path/'train_nir',
                    base_path/'train_gt',transform=transformations)


In [14]:
train_length=int(0.712* len(dataset))

test_length=len(dataset)-train_length
#val_length=500

train_dataset,test_dataset=torch.utils.data.random_split(dataset,(train_length,test_length))
#test_length=test_length-500
#val_set,_=torch.utils.data.random_split(test_dataset,(val_length,test_length))
batch_size= 12

trainloader = DataLoader(train_dataset,
        batch_size=batch_size, shuffle=True,num_workers= 2)
testloader = DataLoader(test_dataset,
        batch_size=batch_size, shuffle=False,num_workers=2)

In [6]:
def Metrics(inputs, targets):
    
    inputs = torch.sigmoid(inputs)
    inputs=inputs.round().int()
    targets=targets.int()
    smooth = 1.0
    inputs = inputs.view(-1)
    targets = targets.view(-1)
    intersection = (inputs & targets).float().sum()
    TP=intersection
    FP = ((1-targets) & inputs).float().sum()
    FN = (targets & (1-inputs)).float().sum()
    TN = ((1-targets) & (1-inputs)).float().sum()
    total = (inputs + targets).float().sum()
    union = total - intersection 
    dice = (2.0 * intersection + smooth) / (total + smooth)
    IoU = ((intersection + smooth)/(union + smooth))
    valid = (targets >= 0)
    acc_sum = (valid * (inputs == targets)).sum()
    valid_sum = valid.sum()
    #acc2=(inputs.argmax(dim=1) == targets.float().mean()
    acc = (float(acc_sum) / (valid_sum + 1e-10))
    precision = (intersection/(FP+intersection+1e-5))
    recall = (intersection/(FN+intersection+1e-5))
    specificity = (TN/(TN+FP+1e-5))
    metrics={'IoU':IoU, 'Dice':dice, 'Pixel_Acc': acc, 'Precision': precision,'Recall': recall, 'Specificity':specificity}
    return metrics


In [7]:
def train(network,criterion, optimizer, trainloader):
    loss_train = 0
    acc_train = 0
    network.train()
    
    for step in tqdm(range(len(trainloader))):

        images , masks = next(iter(trainloader))
        
        # move the images and labels to GPU
        images = images.to(device)
        masks = masks.to(device)
      

        pred = network(images)
        
      
        # clear all the gradients before calculating them
        optimizer.zero_grad()
        #v_pix=val_metrics['Pixel_Acc']
        #val_dice = val_metrics['Dice']
        # find the loss for the current step
        loss_train_step = criterion(pred , masks)
        
        # find accuracy
        acc_train_ = Metrics(pred,masks)
        acc_train_step=acc_train_['IoU']
        # calculate the gradients
        loss_train_step.backward()
        
        # update the parameters
        optimizer.step()
        
        loss_train += loss_train_step.item()
        acc_train += acc_train_step  
            
    loss_train /= len(trainloader)
    acc_train /= len(trainloader)
    #gtrain_dice/= len(testloader)
    #gtrain_pix/= len(testloader)
    #print(pred.max(),pred.min(),masks.max(),masks.min())
    return loss_train, acc_train,acc_train_  
        
def validate(network,criterion, testloader): 
    loss_valid = 0
    acc_valid = 0
    gval_dice=0
    gval_pix=0
    network.eval()  

    for step in tqdm(range(len(testloader))):

        images , masks = next(iter(testloader))
        
        # move the images and labels to GPU
        images = images.to(device)
        masks = masks.to(device)
        
        
        pred = network(images)
        #pred=torch.sigmoid(pred)
        
      
        # clear all the gradients before calculating them
        optimizer.zero_grad()
        
        # find the loss and acc for the current step
        
        loss_valid_step = criterion(pred , masks)
        
        # find accuracy
        val_metrics=Metrics(pred,masks)
        acc_valid_step=val_metrics['IoU']
        #val_pix=val_metrics['Pixel_Acc']
        #val_dice = val_metrics['Dice']
        # calculate the gradients
        #loss_train_step.backward()
        #print(loss_train_step,masks.shape)
        #print(acc_train_)
        # update the parameters
        #optimizer.step()
        
        acc_val = val_metrics
       # acc_valid_step=acc_val['IoU']
        loss_valid += loss_valid_step.item()
        acc_valid += acc_valid_step
        
        
        #print(loss_tvalid_step,masks.shape)
        #print(acc_val)

    loss_valid /= len(testloader)
    acc_valid /= len(testloader)
    #gval_dice/= len(testloader)
    #gval_pix/= len(testloader)
    #print(pred.max(),pred.min(),masks.max(),masks.min())
    return loss_valid, acc_valid,acc_val

In [8]:
!pip install segmentation_models_pytorch 
import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7,     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=4,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
)

#mask=torch.randn(4,1,384,384)
#target=torch.randn(4,1,384,384)


#z=m(target,mask)
device = "cpu"

if torch.cuda.is_available():
    device = "cuda"

print(device)

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.1.3-py3-none-any.whl (66 kB)
[K     |████████████████████████████████| 66 kB 594 kB/s eta 0:00:01
[?25hCollecting pretrainedmodels==0.7.4
  Downloading pretrainedmodels-0.7.4.tar.gz (58 kB)
[K     |████████████████████████████████| 58 kB 1.8 MB/s eta 0:00:01
[?25hCollecting efficientnet-pytorch==0.6.3
  Downloading efficientnet_pytorch-0.6.3.tar.gz (16 kB)
Collecting timm==0.3.2
  Downloading timm-0.3.2-py3-none-any.whl (244 kB)
[K     |████████████████████████████████| 244 kB 3.0 MB/s eta 0:00:01
Building wheels for collected packages: efficientnet-pytorch, pretrainedmodels
  Building wheel for efficientnet-pytorch (setup.py) ... [?25ldone
[?25h  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.6.3-py3-none-any.whl size=12419 sha256=a827aa5fa06f7493b9fdea9f8eed3c60d6b5986f62595ef7d3ad2a9ecbcd5ebd
  Stored in directory: /root/.cache/pip/wheels/90/6b/0c/f0ad36d00310e65390b0d4

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


  0%|          | 0.00/83.3M [00:00<?, ?B/s]

cuda


In [15]:
wandb.init(name='Clouds', 
           project='UNetResnet+WBCE+Augs',
           notes='RGBNIR', 
           #tags=['Replay-Attack','Cyclic_LR'],
           entity='creganstark')

# WandB Configurations (optional)        
wandb.config.lr = 7e-3
#model=model.float()
#model=model.double()
#model = model.to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(2.3053))
#criterion = FocalLoss()
optimizer = optim.Adam(model.parameters(),lr  =  wandb.config.lr)# ,
                      #momentum     = 0.9,
                      #nesterov     = True,
                      #weight_decay = 5e-4)

#scheduler = optim.lr_scheduler.CyclicLR(optimizer,base_lr=wandb.config.lr,max_lr=1e-3 ,step_size_up=2000)
#scheduler = optim.lr_scheduler.CyclicLR(optimizer,base_lr=wandb.config.lr,max_lr=1e-3 ,step_size_up=2000)
  

# Log the network weight histograms (optional)
wandb.watch(model)

num_epochs = 20
start_time = time.time()
prev_acc=0
for epoch in range(1, num_epochs+1):
    
    loss_train, acc_train,metric_train = train(model, criterion, optimizer, trainloader)
    loss_valid, acc_valid,metric_val = validate(model, criterion, testloader)
    
    print('Epoch: {}  Train Loss: {:.4f}  Train IoU: {:.4f}  Valid Loss: {:.4f}  Valid IoU: {:.4f}'.
          format(epoch, loss_train, acc_train, loss_valid, acc_valid))
    print('Train Metrics',metric_train)
    print('Val. Metrics',metric_val)

    # Log the loss and accuracy values at the end of each epoch
    wandb.log({
        "Epoch": epoch,
        "Train IoU": metric_train['IoU'],
        "Train Dice": metric_train['Dice'],
        "Train Pixel Acc": metric_train['Pixel_Acc'],
        "Train Precision":metric_train["Precision"],
        "Train Recall": metric_train['Recall'],
        "Train Specificity":metric_train['Specificity'],
        "Train Loss": loss_train,
        "Val IoU": metric_val['IoU'],
        "Val Dice ": metric_val['Dice'],
        "Val Pixel Acc": metric_val['Pixel_Acc'],
        "Val Precision":metric_val["Precision"],
        "Val Recall": metric_val['Recall'],
        "Val Specificity":metric_val['Specificity'],
        "Val Loss" : loss_valid
       })
    if acc_valid>prev_acc:
      prev_acc=acc_valid
      paths = "model"+str(acc_valid)+".pt"
      print('Saving Model')
      torch.save({
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'val_loss': loss_valid,
          'val_acc':acc_valid,
          'train_acc':acc_train,
          'loss_acc':loss_train,
          }, str(paths))

print("Time Elapsed : {:.4f}s".format(time.time() - start_time))

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  0%|          | 0/499 [00:00<?, ?it/s]

  0%|          | 0/202 [00:00<?, ?it/s]

Epoch: 1  Train Loss: 0.5199  Train IoU: 0.5891  Valid Loss: 0.2847  Valid IoU: 0.7792
Saving Model


  0%|          | 0/499 [00:00<?, ?it/s]

  0%|          | 0/202 [00:00<?, ?it/s]

Epoch: 2  Train Loss: 0.4577  Train IoU: 0.6107  Valid Loss: 0.4382  Valid IoU: 0.7997
Saving Model


  0%|          | 0/499 [00:00<?, ?it/s]

  0%|          | 0/202 [00:00<?, ?it/s]

Epoch: 3  Train Loss: 0.4290  Train IoU: 0.6239  Valid Loss: 0.3488  Valid IoU: 0.6702


  0%|          | 0/499 [00:00<?, ?it/s]

  0%|          | 0/202 [00:00<?, ?it/s]

Epoch: 4  Train Loss: 0.4179  Train IoU: 0.6312  Valid Loss: 0.1707  Valid IoU: 0.8623
Saving Model


  0%|          | 0/499 [00:00<?, ?it/s]

  0%|          | 0/202 [00:00<?, ?it/s]

Epoch: 5  Train Loss: 0.4200  Train IoU: 0.6175  Valid Loss: 0.2196  Valid IoU: 0.8481


  0%|          | 0/499 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fae6614f8c0>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
    self._shutdown_workers()
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/opt/conda/lib/python3.7/multiprocessing/process.py", line 140, in join
    res = self._popen.wait(timeout)
  File "/opt/conda/lib/python3.7/multiprocessing/popen_fork.py", line 45, in wait
    if not wait([self.sentinel], timeout):
  File "/opt/conda/lib/python3.7/multiprocessing/connection.py", line 921, in wait
    ready = selector.select(timeout)
  File "/opt/conda/lib/python3.7/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt: 


KeyboardInterrupt: 

In [23]:
len(val_set)

350

In [None]:
wandb.init(name='Clouds', 
           project='UNetResnet+WBCE+Augs',
           notes='RGBNIR', 
           #tags=['Replay-Attack','Cyclic_LR'],
           entity='creganstark')

# WandB Configurations (optional)        
wandb.config.lr = 9e-3
#model=model.float()
#model=model.double()
#model = model.to(device)

criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(2.3053))
#criterion = FocalLoss()
optimizer = optim.Adam(model.parameters(),lr  =  wandb.config.lr)# ,
                      #momentum     = 0.9,
                      #nesterov     = True,
                      #weight_decay = 5e-4)

#scheduler = optim.lr_scheduler.CyclicLR(optimizer,base_lr=wandb.config.lr,max_lr=1e-3 ,step_size_up=2000)
#scheduler = optim.lr_scheduler.CyclicLR(optimizer,base_lr=wandb.config.lr,max_lr=1e-3 ,step_size_up=2000)
  

# Log the network weight histograms (optional)
#wandb.watch(model)

num_epochs = 5
start_time = time.time()
prev_acc=0
for epoch in range(1, num_epochs+1):
    
    loss_train, acc_train,metric_train = train(model, criterion, optimizer, trainloader)
    loss_valid, acc_valid,metric_val = validate(model, criterion, testloader)
    
    print('Epoch: {}  Train Loss: {:.4f}  Train IoU: {:.4f}  Valid Loss: {:.4f}  Valid IoU: {:.4f}'.
          format(epoch, loss_train, acc_train, loss_valid, acc_valid))
    print('Train Metrics',metric_train)
    print('Val. Metrics',metric_val)

    # Log the loss and accuracy values at the end of each epoch
    wandb.log({
        "Epoch": epoch,
        "Train IoU": metric_train['IoU'],
        "Train Dice": metric_train['Dice'],
        "Train Pixel Acc": metric_train['Pixel_Acc'],
        "Train Precision":metric_train["Precision"],
        "Train Recall": metric_train['Recall'],
        "Train Specificity":metric_train['Specificity'],
        "Train Loss": loss_train,
        "Val IoU": metric_val['IoU'],
        "Val Dice ": metric_val['Dice'],
        "Val Pixel Acc": metric_val['Pixel_Acc'],
        "Val Precision":metric_val["Precision"],
        "Val Recall": metric_val['Recall'],
        "Val Specificity":metric_val['Specificity'],
        "Val Loss" : loss_valid
       })
    if acc_valid>prev_acc:
      prev_acc=acc_valid
      paths = "model"+str(acc_valid)+".pt"
      print('Saving Model')
      torch.save({
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'val_loss': loss_valid,
          'val_acc':acc_valid,
          'train_acc':acc_train,
          'loss_acc':loss_train,
          }, str(paths))

print("Time Elapsed : {:.4f}s".format(time.time() - start_time))

[34m[1mwandb[0m: [32m[41mERROR[0m Problem finishing run
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_run.py", line 1485, in _atexit_cleanup
    self._on_finish()
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/wandb_run.py", line 1642, in _on_finish
    self._backend.interface.publish_telemetry(self._telemetry_obj)
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/interface/interface.py", line 226, in publish_telemetry
    self._publish(rec)
  File "/opt/conda/lib/python3.7/site-packages/wandb/sdk/interface/interface.py", line 518, in _publish
    raise Exception("The wandb backend process has shutdown")
Exception: The wandb backend process has shutdown


[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


  0%|          | 0/490 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

Epoch: 1  Train Loss: 0.4553  Train IoU: 0.6012  Valid Loss: 0.6552  Valid IoU: 0.6537
Train Metrics {'IoU': tensor(0.8930, device='cuda:0'), 'Dice': tensor(0.9435, device='cuda:0'), 'Pixel_Acc': tensor(0.9444, device='cuda:0'), 'Precision': tensor(0.9625, device='cuda:0'), 'Recall': tensor(0.9252, device='cuda:0'), 'Specificity': tensor(0.9637, device='cuda:0')}
Val. Metrics {'IoU': tensor(0.5229, device='cuda:0'), 'Dice': tensor(0.6867, device='cuda:0'), 'Pixel_Acc': tensor(0.7753, device='cuda:0'), 'Precision': tensor(0.6337, device='cuda:0'), 'Recall': tensor(0.7494, device='cuda:0'), 'Specificity': tensor(0.7880, device='cuda:0')}
Saving Model


  0%|          | 0/490 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

Epoch: 2  Train Loss: 0.4160  Train IoU: 0.6378  Valid Loss: 0.5492  Valid IoU: 0.6828
Train Metrics {'IoU': tensor(0.6091, device='cuda:0'), 'Dice': tensor(0.7571, device='cuda:0'), 'Pixel_Acc': tensor(0.8013, device='cuda:0'), 'Precision': tensor(0.6112, device='cuda:0'), 'Recall': tensor(0.9943, device='cuda:0'), 'Specificity': tensor(0.7139, device='cuda:0')}
Val. Metrics {'IoU': tensor(0.5959, device='cuda:0'), 'Dice': tensor(0.7468, device='cuda:0'), 'Pixel_Acc': tensor(0.7698, device='cuda:0'), 'Precision': tensor(0.5994, device='cuda:0'), 'Recall': tensor(0.9901, device='cuda:0'), 'Specificity': tensor(0.6549, device='cuda:0')}
Saving Model


  0%|          | 0/490 [00:00<?, ?it/s]

In [16]:
paths='5epochsunet100.pt'
torch.save({
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'val_loss': loss_valid,
          'val_acc':acc_valid,
          'train_acc':acc_train,
          'loss_acc':loss_train,
          }, str(paths))

In [17]:
filename = 'modeltensor(0.8623, device='+'-cuda'+':'+'0-'+').pt'
filename=filename.replace('-',"'")

In [18]:
#filename = 'modeltensor(0.8413, device='cuda:0').pt
def load_checkpoint(model,filepath):
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    #for parameter in model.parameters():
        #parameter.requires_grad = True
    return model
model=load_checkpoint(model,filename)

In [20]:
train_length=int(0.7* len(dataset))

test_length=len(dataset)-train_length
val_length=350

train_dataset,test_dataset=torch.utils.data.random_split(dataset,(train_length,test_length))
test_length=test_length-val_length
val_set,_=torch.utils.data.random_split(test_dataset,(val_length,test_length))
batch_size= 12

trainloader = DataLoader(train_dataset,
        batch_size=batch_size, shuffle=True,num_workers= 2)
testloader = DataLoader(val_set,
        batch_size=batch_size, shuffle=False,num_workers=2)