# SwinIR: XLA Inference Tests

In [1]:
import torch
torch.cuda.is_available()

False

In [2]:
torch.__version__

'1.9.0+cu102'

In [3]:
import torch_xla.core.xla_model as xm

In [4]:
device = xm.xla_device()
devices = xm.get_xla_supported_devices(xm.xla_device_hw(device))
devices

2022-07-18 16:46:07.322339: E tensorflow/core/framework/op_kernel.cc:1623] OpKernel ('op: "TPURoundRobin" device_type: "CPU"') for unknown op: TPURoundRobin
2022-07-18 16:46:07.322394: E tensorflow/core/framework/op_kernel.cc:1623] OpKernel ('op: "TpuHandleToProtoKey" device_type: "CPU"') for unknown op: TpuHandleToProtoKey


['xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7', 'xla:8']

In [5]:
tile = 256
torch.set_grad_enabled(False)
torch.set_num_threads(1)

## Model Loading

This was extracted from the `real_sr` task of `main_test_swinir.py`.

In [6]:
from models.network_swinir import SwinIR as net

In [7]:
large_model = False
upscale = 4

In [8]:
if not large_model:
    # use 'nearest+conv' to avoid block artifacts
    model = net(upscale=upscale, in_chans=3, img_size=64, window_size=8,
                img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv')
else:
    # larger model size; use '3conv' to save parameters and memory; use ema for GAN training
    model = net(upscale=upscale, in_chans=3, img_size=64, window_size=8,
                img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
                num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
param_key_g = 'params_ema'

In [9]:
# Weights for the "small" model, we need to download the large ones.
model_path = 'model_zoo/swinir/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth'

In [10]:
pretrained_model = torch.load(model_path)
model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model.keys() else pretrained_model, strict=True)

<All keys matched successfully>

Set the model to evaluation mode

In [11]:
model = model.eval().to(device)

## Preparation

We use OpenCV to load the image, just as the sample scripts do. Note that OpenCV uses BGR format, conversion to RGB is performed later.

In [12]:
import cv2
import numpy as np

from matplotlib import pyplot as plt

## Test single image inference on a single XLA device

In [13]:
# image_path = 'testsets/dalle-mini-samples/giraffe.png'
# img_lq = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
# img_lq = img_lq[:, :, [2, 1, 0]]
# plt.imshow(img_lq)

In [14]:
# img_lq.shape

In [15]:
# img_lq = np.transpose(img_lq, (2, 0, 1))  # HCW to CHW
# img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  # CHW-RGB to NCHW-RGB

In [16]:
# img_lq.shape

In [17]:
# example_input = img_lq
# example_input = example_input.to(device)

The first time it took 3 min to compile. Then it's much faster but still slow: 13s per inference.

In [18]:
# %%time
# out = model(example_input)

In [19]:
# out.shape

In [20]:
# out_img = out.squeeze().cpu().numpy()
# out_img = out_img.clip(0, 1)
# out_img = np.uint(255 * out_img)
# out_img = np.transpose(out_img, [1, 2, 0])

In [21]:
# out_img.shape

In [22]:
# plt.imshow(out_img)

## Image Batch (single XLA device)

In [23]:
from pathlib import Path

In [24]:
images = Path('testsets/dalle-mini-samples/')

In [25]:
def read_image(image_path):
    img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
    img = img[:, :, [2, 1, 0]]
    img = np.transpose(img, (2, 0, 1))  # HCW to CHW
    img = torch.from_numpy(img).float().unsqueeze(0)
    return img

We have 28 images. Let's cycle through them to create whatever batch size we want to test.

In my 3090, I found that it works with a batch size of 30, but a bs of 32 caused it to fail with a weird error. Using `nvtop`, it looked like those two additional images would fit, but I don't know.

On TPU v2-8 I'll start with a batch size of 10 as a single TPU core has 8 GB of memory.

In [26]:
import itertools

In [27]:
bs = 2

In [28]:
all_images = None
for image_path, _ in zip(itertools.cycle(images.iterdir()), range(bs)):
    print(image_path)
    img = read_image(str(image_path))
    all_images = img if all_images is None else torch.vstack((all_images, img))

testsets/dalle-mini-samples/avocado.png
testsets/dalle-mini-samples/sunset.png


In [29]:
all_images.shape

torch.Size([2, 3, 256, 256])

With a batch size of 2, the first time it took ~3min to compile. Then ~16s per inference, which is slow because the batch size is so small.

With a batch size of 3, the first time it took ~8min and then ~23.6s per inference.

In [35]:
%%time
all_images = all_images.to(device)
out_images = model(all_images)
out_images = out_images.cpu().numpy()

CPU times: user 42.8 s, sys: 1.77 s, total: 44.6 s
Wall time: 15.7 s


In [31]:
out_images.shape

(2, 3, 1024, 1024)

In [32]:
torch.cuda.empty_cache()

In [33]:
10.5 / bs

5.25

------