Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Tokenizer(
const val NAME = "Tokenizer"
}

override fun load(
override fun loadModule(
tokenizerSource: String,
promise: Promise,
) {
Expand Down
6 changes: 3 additions & 3 deletions ios/RnExecutorch/Tokenizer.mm
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ @implementation Tokenizer {

RCT_EXPORT_MODULE()

- (void)load:(NSString *)tokenizerSource
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
- (void)loadModule:(NSString *)tokenizerSource
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
@try {
tokenizer =
[[HuggingFaceTokenizer alloc] initWithTokenizerPath:tokenizerSource];
Expand Down
6 changes: 3 additions & 3 deletions src/controllers/OCRController.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { symbols } from '../constants/ocr/symbols';
import { ETError, getError } from '../Error';
import { _OCRModule } from '../native/RnExecutorchModules';
import { OCRNativeModule } from '../native/RnExecutorchModules';
import { ResourceSource } from '../types/common';
import { OCRLanguage } from '../types/ocr';
import { ResourceFetcher } from '../utils/ResourceFetcher';

export class OCRController {
private nativeModule: _OCRModule;
private nativeModule: typeof OCRNativeModule;
public isReady: boolean = false;
public isGenerating: boolean = false;
public error: string | null = null;
Expand All @@ -21,7 +21,7 @@ export class OCRController {
isGeneratingCallback = (_isGenerating: boolean) => {},
errorCallback = (_error: string) => {},
}) {
this.nativeModule = new _OCRModule();
this.nativeModule = OCRNativeModule;
this.modelDownloadProgressCallback = modelDownloadProgressCallback;
this.isReadyCallback = isReadyCallback;
this.isGeneratingCallback = isGeneratingCallback;
Expand Down
32 changes: 19 additions & 13 deletions src/controllers/SpeechToTextController.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { _SpeechToTextModule } from '../native/RnExecutorchModules';
import { SpeechToTextNativeModule } from '../native/RnExecutorchModules';
import * as FileSystem from 'expo-file-system';
import { ResourceFetcher } from '../utils/ResourceFetcher';
import { ResourceSource } from '../types/common';
Expand Down Expand Up @@ -36,7 +36,7 @@ const longCommonInfPref = (seq1: number[], seq2: number[]) => {
};

export class SpeechToTextController {
private nativeModule: _SpeechToTextModule;
private speechToTextNativeModule: typeof SpeechToTextNativeModule;

private overlapSeconds!: number;
private windowSize!: number;
Expand Down Expand Up @@ -91,7 +91,7 @@ export class SpeechToTextController {
isGeneratingCallback?.(isGenerating);
};
this.onErrorCallback = onErrorCallback;
this.nativeModule = new _SpeechToTextModule();
this.speechToTextNativeModule = SpeechToTextNativeModule;
this.configureStreaming(
overlapSeconds,
windowSize,
Expand Down Expand Up @@ -121,18 +121,19 @@ export class SpeechToTextController {

try {
this.tokenMapping = await this.fetchTokenizer(tokenizerSource);
[encoderSource, decoderSource] = await ResourceFetcher.fetchMultipleResources(
this.modelDownloadProgessCallback,
encoderSource || this.config.sources.encoder,
decoderSource || this.config.sources.decoder
);
[encoderSource, decoderSource] =
await ResourceFetcher.fetchMultipleResources(
this.modelDownloadProgessCallback,
encoderSource || this.config.sources.encoder,
decoderSource || this.config.sources.decoder
);
} catch (e) {
this.onErrorCallback?.(e);
return;
}

try {
await this.nativeModule.loadModule(modelName, [
await this.speechToTextNativeModule.loadModule(modelName, [
encoderSource!,
decoderSource!,
]);
Expand Down Expand Up @@ -220,14 +221,14 @@ export class SpeechToTextController {
let finalSeq: number[] = [];
let seq = [lastToken];
try {
await this.nativeModule.encode(this.chunks!.at(chunkId)!);
await this.speechToTextNativeModule.encode(this.chunks!.at(chunkId)!);
} catch (error) {
this.onErrorCallback?.(`Encode ${error}`);
return '';
}
while (lastToken !== this.config.tokenizer.eos) {
try {
lastToken = await this.nativeModule.decode(seq);
lastToken = await this.decode(seq);
} catch (error) {
this.onErrorCallback?.(`Decode ${error}`);
return '';
Expand Down Expand Up @@ -302,10 +303,15 @@ export class SpeechToTextController {
}

public async encode(waveform: number[]) {
return await this.nativeModule.encode(waveform);
return await this.speechToTextNativeModule.encode(waveform);
}

public async decode(seq: number[], encodings?: number[]) {
return await this.nativeModule.decode(seq, encodings);
/*
CAUTION: When you pass an empty encodings array, the cached encodings are used.
For instance, when you call .encode() for the first time, it internally caches
the encoding results, so omitting `encodings` here falls back to that cache.
*/
return await this.speechToTextNativeModule.decode(seq, encodings || []);
}
}
10 changes: 5 additions & 5 deletions src/controllers/VerticalOCRController.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { symbols } from '../constants/ocr/symbols';
import { ETError, getError } from '../Error';
import { _VerticalOCRModule } from '../native/RnExecutorchModules';
import { VerticalOCRNativeModule } from '../native/RnExecutorchModules';
import { ResourceSource } from '../types/common';
import { OCRLanguage } from '../types/ocr';
import { ResourceFetcher } from '../utils/ResourceFetcher';

export class VerticalOCRController {
private nativeModule: _VerticalOCRModule;
private ocrNativeModule: typeof VerticalOCRNativeModule;
public isReady: boolean = false;
public isGenerating: boolean = false;
public error: string | null = null;
Expand All @@ -21,7 +21,7 @@ export class VerticalOCRController {
isGeneratingCallback = (_isGenerating: boolean) => {},
errorCallback = (_error: string) => {},
}) {
this.nativeModule = new _VerticalOCRModule();
this.ocrNativeModule = VerticalOCRNativeModule;
this.modelDownloadProgressCallback = modelDownloadProgressCallback;
this.isReadyCallback = isReadyCallback;
this.isGeneratingCallback = isGeneratingCallback;
Expand Down Expand Up @@ -63,7 +63,7 @@ export class VerticalOCRController {
: recognizerSources.recognizerLarge
);

await this.nativeModule.loadModule(
await this.ocrNativeModule.loadModule(
paths[0]!,
paths[1]!,
paths[2]!,
Expand Down Expand Up @@ -93,7 +93,7 @@ export class VerticalOCRController {
try {
this.isGenerating = true;
this.isGeneratingCallback(this.isGenerating);
return await this.nativeModule.forward(input);
return await this.ocrNativeModule.forward(input);
} catch (e) {
throw new Error(getError(e));
} finally {
Expand Down
32 changes: 5 additions & 27 deletions src/hooks/computer_vision/useClassification.ts
Original file line number Diff line number Diff line change
@@ -1,31 +1,9 @@
import { useState } from 'react';
import { _ClassificationModule } from '../../native/RnExecutorchModules';
import { ClassificationModule } from '../../modules/computer_vision/ClassificationModule';
import { ResourceSource } from '../../types/common';
import { useModule } from '../useModule';

interface Props {
modelSource: string | number;
}

export const useClassification = ({
modelSource,
}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
downloadProgress: number;
forward: (input: string) => Promise<{ [category: string]: number }>;
} => {
const [module, _] = useState(() => new _ClassificationModule());
const {
error,
isReady,
isGenerating,
downloadProgress,
forwardImage: forward,
} = useModule({
modelSource,
module,
});

return { error, isReady, isGenerating, downloadProgress, forward };
};
}: {
modelSource: ResourceSource;
}) => useModule({ module: ClassificationModule, loadArgs: [modelSource] });
65 changes: 3 additions & 62 deletions src/hooks/computer_vision/useImageSegmentation.ts
Original file line number Diff line number Diff line change
@@ -1,68 +1,9 @@
import { useState } from 'react';
import { _ImageSegmentationModule } from '../../native/RnExecutorchModules';
import { ETError, getError } from '../../Error';
import { useModule } from '../useModule';
import { DeeplabLabel } from '../../types/image_segmentation';
import { ImageSegmentationModule } from '../../modules/computer_vision/ImageSegmentationModule';

interface Props {
modelSource: string | number;
}

export const useImageSegmentation = ({
modelSource,
}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
downloadProgress: number;
forward: (
input: string,
classesOfInterest?: DeeplabLabel[],
resize?: boolean
) => Promise<{ [key in DeeplabLabel]?: number[] }>;
} => {
const [module, _] = useState(() => new _ImageSegmentationModule());
const [isGenerating, setIsGenerating] = useState(false);
const { error, isReady, downloadProgress } = useModule({
modelSource,
module,
});

const forward = async (
input: string,
classesOfInterest?: DeeplabLabel[],
resize?: boolean
) => {
if (!isReady) {
throw new Error(getError(ETError.ModuleNotLoaded));
}
if (isGenerating) {
throw new Error(getError(ETError.ModelGenerating));
}

try {
setIsGenerating(true);
const stringDict = await module.forward(
input,
(classesOfInterest || []).map((label) => DeeplabLabel[label]),
resize || false
);

let enumDict: { [key in DeeplabLabel]?: number[] } = {};

for (const key in stringDict) {
if (key in DeeplabLabel) {
const enumKey = DeeplabLabel[key as keyof typeof DeeplabLabel];
enumDict[enumKey] = stringDict[key];
}
}
return enumDict;
} catch (e) {
throw new Error(getError(e));
} finally {
setIsGenerating(false);
}
};

return { error, isReady, isGenerating, downloadProgress, forward };
};
export const useImageSegmentation = ({ modelSource }: Props) =>
useModule({ module: ImageSegmentationModule, loadArgs: [modelSource] });
32 changes: 5 additions & 27 deletions src/hooks/computer_vision/useObjectDetection.ts
Original file line number Diff line number Diff line change
@@ -1,32 +1,10 @@
import { useState } from 'react';
import { _ObjectDetectionModule } from '../../native/RnExecutorchModules';
import { ResourceSource } from '../../types/common';
import { useModule } from '../useModule';
import { Detection } from '../../types/object_detection';
import { ObjectDetectionModule } from '../../modules/computer_vision/ObjectDetectionModule';

interface Props {
modelSource: string | number;
modelSource: ResourceSource;
}

export const useObjectDetection = ({
modelSource,
}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
downloadProgress: number;
forward: (input: string) => Promise<Detection[]>;
} => {
const [module, _] = useState(() => new _ObjectDetectionModule());
const {
error,
isReady,
isGenerating,
downloadProgress,
forwardImage: forward,
} = useModule({
modelSource,
module,
});

return { error, isReady, isGenerating, downloadProgress, forward };
};
export const useObjectDetection = ({ modelSource }: Props) =>
useModule({ module: ObjectDetectionModule, loadArgs: [modelSource] });
31 changes: 5 additions & 26 deletions src/hooks/computer_vision/useStyleTransfer.ts
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
import { useState } from 'react';
import { _StyleTransferModule } from '../../native/RnExecutorchModules';
import { ResourceSource } from '../../types/common';
import { useModule } from '../useModule';
import { StyleTransferModule } from '../../modules/computer_vision/StyleTransferModule';

interface Props {
modelSource: string | number;
modelSource: ResourceSource;
}

export const useStyleTransfer = ({
modelSource,
}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
downloadProgress: number;
forward: (input: string) => Promise<string>;
} => {
const [module, _] = useState(() => new _StyleTransferModule());
const {
error,
isReady,
isGenerating,
downloadProgress,
forwardImage: forward,
} = useModule({
modelSource,
module,
});

return { error, isReady, isGenerating, downloadProgress, forward };
};
export const useStyleTransfer = ({ modelSource }: Props) =>
useModule({ module: StyleTransferModule, loadArgs: [modelSource] });
Loading