In [None]:
%%capture
import sys
!{sys.executable} -m pip install SimpleITK==2.2.1 cassandra-driver==3.27.0 diskcache==4.1.0

# The LunaDataset Class

In this notebook, I work through the implementation of the `LunaDataset` class, a
subclass of the `Dataset` class provided by PyTorch.
It is implemented in the `p2ch10/dsets.py` module.

A subclass must implement two methods:

- `__len__`: Returns the number of candidate nodules in the dataset for which image data is available. (The full list of candidates is available in the `candidates.csv` file.)

and

- `__getitem__`: Returns a tuple for a single candidate nodule: the cropped image data (as a tensor) centered on the candidate, its class label, the series UID of the scan, and the center coordinates of the candidate.

In [None]:
import copy

class LunaDataset(Dataset):
    """Dataset of candidate nodules, each served as a cropped chunk of a CT scan.

    Args:
        val_stride: When positive, every ``val_stride``-th candidate belongs to
            the validation split. ``0`` disables the split.
        isValSet_bool: If truthy, keep only the validation candidates (this
            requires ``val_stride > 0``). Otherwise the validation candidates
            are dropped whenever ``val_stride > 0``.
        series_uid: If given, restrict the dataset to candidates whose
            ``series_uid`` equals this value.
    """

    def __init__(self,
                 val_stride=0,
                 isValSet_bool=None,
                 series_uid=None,
            ):
        # Work on a shallow copy: the in-place deletion further down must not
        # touch the list object handed out by getCandidateInfoList().
        candidates = copy.copy(getCandidateInfoList())

        # Optional restriction to a single series.
        if series_uid:
            candidates = [
                cand for cand in candidates if cand.series_uid == series_uid
            ]

        if isValSet_bool:
            # Validation split: keep every val_stride-th candidate.
            assert val_stride > 0, val_stride
            candidates = candidates[::val_stride]
            assert candidates
        elif val_stride > 0:
            # Training split: drop exactly the candidates the validation
            # split would have kept.
            del candidates[::val_stride]
            assert candidates

        self.candidateInfo_list = candidates

        if isValSet_bool:
            split_name = "validation"
        else:
            split_name = "training"
        # `{!r}` formats the dataset object via repr().
        message = "{!r}: {} {} samples".format(
            self,
            len(self.candidateInfo_list),
            split_name,
        )
        log.info(message)

    def __len__(self):
        """Return the number of candidates kept after filtering/splitting."""
        return len(self.candidateInfo_list)

    def __getitem__(self, ndx):
        """Return the sample stored at position ``ndx``.

        Returns:
            A 4-tuple ``(candidate_t, pos_t, series_uid, center_t)``:
            the cropped chunk as a float32 tensor with a leading channel
            dimension, a two-element long tensor ``[not nodule, nodule]``,
            the candidate's ``series_uid``, and the chunk center returned by
            ``getCtRawCandidate`` wrapped in a tensor.
        """
        info = self.candidateInfo_list[ndx]

        # (32, 48, 48) is the size of the cropped chunk requested around
        # the candidate center.
        chunk_a, center_irc = getCtRawCandidate(
            info.series_uid,
            info.center_xyz,
            (32, 48, 48),
        )

        # array -> float32 tensor -> prepend the channel dimension
        candidate_t = torch.from_numpy(chunk_a).to(torch.float32).unsqueeze(0)

        # Two-element label: index 0 flags "not a nodule", index 1 "nodule".
        is_nodule = info.isNodule_bool
        pos_t = torch.tensor([not is_nodule, is_nodule], dtype=torch.long)

        center_t = torch.tensor(center_irc)
        return candidate_t, pos_t, info.series_uid, center_t

## Caching

The `__getitem__` method of a `LunaDataset` returns one specific nodule from a CT scan
at a time. Without caching, that would require reading the large image from disk
_every time_, slowing down the workflow.

We can speed up the operation by caching the image, using different approaches.

In [None]:
@functools.lru_cache(maxsize=1, typed=True)
def getCt(series_uid):
    """Build the ``Ct`` object for ``series_uid``, memoized in memory.

    The cache holds a single entry, so asking for a different ``series_uid``
    evicts the previously cached object.
    """
    ct = Ct(series_uid)
    return ct

@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    """Return the chunk around ``center_xyz`` and its center, memoized via ``raw_cache``.

    The scan is obtained through ``getCt(series_uid)`` and the crop is delegated
    to its ``getRawCandidate(center_xyz, width_irc)`` method, which yields the
    chunk and the center that are returned here as a 2-tuple.
    """
    scan = getCt(series_uid)
    chunk_a, chunk_center_irc = scan.getRawCandidate(center_xyz, width_irc)
    return chunk_a, chunk_center_irc


- The `getCt` helper function retrieves the full image data for a single `series_uid` and
caches it _in memory_.
    - Only a single image is cached in memory at any time (`maxsize=1`).
- The `getCtRawCandidate` helper function retrieves the region around a single nodule
and caches it _on disk_.
    - the `memoize` decorator of the `diskcache` cache object (called `raw_cache` in this
      case) plays the same role as `functools.lru_cache`, but stores the results on disk.