Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class OpenAiAudioSpeechProperties extends OpenAiParentProperties {

public static final String DEFAULT_SPEECH_MODEL = OpenAiAudioApi.TtsModel.GPT_4_O_MINI_TTS.getValue();

private static final Float SPEED = 1.0f;
private static final Double SPEED = 1.0;

private static final String VOICE = OpenAiAudioApi.SpeechRequest.Voice.ALLOY.getValue();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import reactor.core.publisher.Flux;

import org.springframework.ai.audio.tts.Speech;
import org.springframework.ai.audio.tts.StreamingTextToSpeechModel;
import org.springframework.ai.audio.tts.TextToSpeechModel;
import org.springframework.ai.audio.tts.TextToSpeechPrompt;
import org.springframework.ai.audio.tts.TextToSpeechResponse;
Expand All @@ -35,12 +34,11 @@
import org.springframework.util.MultiValueMap;

/**
* Implementation of the {@link TextToSpeechModel} and {@link StreamingTextToSpeechModel}
* interfaces
* Implementation of the {@link TextToSpeechModel} interface for ElevenLabs TTS API.
*
* @author Alexandros Pappas
*/
public class ElevenLabsTextToSpeechModel implements TextToSpeechModel, StreamingTextToSpeechModel {
public class ElevenLabsTextToSpeechModel implements TextToSpeechModel {

private final Logger logger = LoggerFactory.getLogger(getClass());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@

package org.springframework.ai.openai;

import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.audio.tts.Speech;
import org.springframework.ai.audio.tts.TextToSpeechModel;
import org.springframework.ai.audio.tts.TextToSpeechOptions;
import org.springframework.ai.audio.tts.TextToSpeechPrompt;
import org.springframework.ai.audio.tts.TextToSpeechResponse;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.openai.api.OpenAiAudioApi;
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
import org.springframework.ai.openai.audio.speech.Speech;
import org.springframework.ai.openai.audio.speech.SpeechModel;
import org.springframework.ai.openai.audio.speech.SpeechPrompt;
import org.springframework.ai.openai.audio.speech.SpeechResponse;
import org.springframework.ai.openai.audio.speech.StreamingSpeechModel;
import org.springframework.ai.openai.metadata.audio.OpenAiAudioSpeechResponseMetadata;
import org.springframework.ai.openai.metadata.support.OpenAiResponseHeaderExtractor;
import org.springframework.ai.retry.RetryUtils;
Expand All @@ -46,13 +48,13 @@
* @see OpenAiAudioApi
* @since 1.0.0-M1
*/
public class OpenAiAudioSpeechModel implements SpeechModel, StreamingSpeechModel {
public class OpenAiAudioSpeechModel implements TextToSpeechModel {

/**
* The speed of the default voice synthesis.
* @see OpenAiAudioSpeechOptions
*/
private static final Float SPEED = 1.0f;
private static final Double SPEED = 1.0;

private final Logger logger = LoggerFactory.getLogger(getClass());

Expand Down Expand Up @@ -118,14 +120,14 @@ public OpenAiAudioSpeechModel(OpenAiAudioApi audioApi, OpenAiAudioSpeechOptions

/**
 * Convenience overload that wraps plain text in a {@link TextToSpeechPrompt} and
 * delegates to {@link #call(TextToSpeechPrompt)}.
 * @param text the text to synthesize
 * @return the output bytes of the response's result
 */
@Override
public byte[] call(String text) {
    // Only the TextToSpeechPrompt path is kept: the former SpeechPrompt statements
    // duplicated this logic and made the second return unreachable.
    TextToSpeechPrompt prompt = new TextToSpeechPrompt(text);
    return call(prompt).getResult().getOutput();
}

@Override
public SpeechResponse call(SpeechPrompt speechPrompt) {
public TextToSpeechResponse call(TextToSpeechPrompt prompt) {

OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt);
OpenAiAudioApi.SpeechRequest speechRequest = createRequest(prompt);

ResponseEntity<byte[]> speechEntity = this.retryTemplate
.execute(ctx -> this.audioApi.createSpeech(speechRequest));
Expand All @@ -134,48 +136,42 @@ public SpeechResponse call(SpeechPrompt speechPrompt) {

if (speech == null) {
logger.warn("No speech response returned for speechRequest: {}", speechRequest);
return new SpeechResponse(new Speech(new byte[0]));
return new TextToSpeechResponse(List.of(new Speech(new byte[0])));
}

RateLimit rateLimits = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(speechEntity);

return new SpeechResponse(new Speech(speech), new OpenAiAudioSpeechResponseMetadata(rateLimits));
return new TextToSpeechResponse(List.of(new Speech(speech)), new OpenAiAudioSpeechResponseMetadata(rateLimits));
}

/**
* Streams the audio response for the given speech prompt.
* @param speechPrompt The speech prompt containing the text and options for speech
* @param prompt The speech prompt containing the text and options for speech
* synthesis.
* @return A Flux of SpeechResponse objects containing the streamed audio and
* @return A Flux of TextToSpeechResponse objects containing the streamed audio and
* metadata.
*/
@Override
public Flux<SpeechResponse> stream(SpeechPrompt speechPrompt) {
public Flux<TextToSpeechResponse> stream(TextToSpeechPrompt prompt) {

OpenAiAudioApi.SpeechRequest speechRequest = createRequest(speechPrompt);
OpenAiAudioApi.SpeechRequest speechRequest = createRequest(prompt);

Flux<ResponseEntity<byte[]>> speechEntity = this.retryTemplate
.execute(ctx -> this.audioApi.stream(speechRequest));

return speechEntity.map(entity -> new SpeechResponse(new Speech(entity.getBody()),
return speechEntity.map(entity -> new TextToSpeechResponse(List.of(new Speech(entity.getBody())),
new OpenAiAudioSpeechResponseMetadata(OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity))));
}

private OpenAiAudioApi.SpeechRequest createRequest(SpeechPrompt request) {
OpenAiAudioSpeechOptions options = this.defaultOptions;

if (request.getOptions() != null) {
if (request.getOptions() instanceof OpenAiAudioSpeechOptions runtimeOptions) {
options = this.merge(runtimeOptions, options);
}
else {
throw new IllegalArgumentException("Prompt options are not of type SpeechOptions: "
+ request.getOptions().getClass().getSimpleName());
}
}
private OpenAiAudioApi.SpeechRequest createRequest(TextToSpeechPrompt prompt) {
OpenAiAudioSpeechOptions runtimeOptions = (prompt
.getOptions() instanceof OpenAiAudioSpeechOptions openAiAudioSpeechOptions) ? openAiAudioSpeechOptions
: null;
OpenAiAudioSpeechOptions options = (runtimeOptions != null) ? this.merge(runtimeOptions, this.defaultOptions)
: this.defaultOptions;

String input = StringUtils.hasText(options.getInput()) ? options.getInput()
: request.getInstructions().getText();
: prompt.getInstructions().getText();

OpenAiAudioApi.SpeechRequest.Builder requestBuilder = OpenAiAudioApi.SpeechRequest.builder()
.model(options.getModel())
Expand All @@ -187,6 +183,11 @@ private OpenAiAudioApi.SpeechRequest createRequest(SpeechPrompt request) {
return requestBuilder.build();
}

/**
 * Returns the default options held by this model.
 * <p>
 * The same instance stored in {@code defaultOptions} is returned; no copy is made.
 * @return the model's default text-to-speech options
 */
@Override
public TextToSpeechOptions getDefaultOptions() {
return this.defaultOptions;
}

private OpenAiAudioSpeechOptions merge(OpenAiAudioSpeechOptions source, OpenAiAudioSpeechOptions target) {
OpenAiAudioSpeechOptions.Builder mergedBuilder = OpenAiAudioSpeechOptions.builder();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.model.ModelOptions;
import org.springframework.ai.audio.tts.TextToSpeechOptions;
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.AudioResponseFormat;
import org.springframework.ai.openai.api.OpenAiAudioApi.SpeechRequest.Voice;

Expand All @@ -33,7 +33,7 @@
* @since 1.0.0-M1
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
public class OpenAiAudioSpeechOptions implements ModelOptions {
public class OpenAiAudioSpeechOptions implements TextToSpeechOptions {

/**
* ID of the model to use for generating the audio. For OpenAI's TTS API, use one of
Expand Down Expand Up @@ -67,7 +67,7 @@ public class OpenAiAudioSpeechOptions implements ModelOptions {
* 4.0 (fastest). Defaults to 1 (normal)
*/
@JsonProperty("speed")
private Float speed;
private Double speed;

public static Builder builder() {
return new Builder();
Expand Down Expand Up @@ -109,14 +109,34 @@ public void setResponseFormat(AudioResponseFormat responseFormat) {
this.responseFormat = responseFormat;
}

public Float getSpeed() {
@Override
public Double getSpeed() {
return this.speed;
}

public void setSpeed(Float speed) {
public void setSpeed(Double speed) {
this.speed = speed;
}

// TextToSpeechOptions interface methods

/**
 * Returns the response format as the lower-cased name of its enum constant.
 * @return the lower-cased constant name, or {@code null} when no response format is
 * set
 */
@Override
public String getFormat() {
    if (this.responseFormat == null) {
        return null;
    }
    // Locale.ROOT keeps the result independent of the JVM's default locale.
    return this.responseFormat.name().toLowerCase(java.util.Locale.ROOT);
}

/**
 * Creates a new options instance carrying the same model, input, voice, response
 * format and speed as this one.
 * @return a separately built {@code OpenAiAudioSpeechOptions} with the same field
 * values
 */
@Override
@SuppressWarnings("unchecked")
public OpenAiAudioSpeechOptions copy() {
    // Each field is passed through the builder so the result is a distinct object.
    OpenAiAudioSpeechOptions duplicate = OpenAiAudioSpeechOptions.builder()
        .model(this.model)
        .input(this.input)
        .voice(this.voice)
        .responseFormat(this.responseFormat)
        .speed(this.speed)
        .build();
    return duplicate;
}

@Override
public int hashCode() {
final int prime = 31;
Expand Down Expand Up @@ -217,7 +237,7 @@ public Builder responseFormat(AudioResponseFormat responseFormat) {
return this;
}

public Builder speed(Float speed) {
public Builder speed(Double speed) {
this.options.speed = speed;
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ public record SpeechRequest(
@JsonProperty("input") String input,
@JsonProperty("voice") String voice,
@JsonProperty("response_format") AudioResponseFormat responseFormat,
@JsonProperty("speed") Float speed) {
@JsonProperty("speed") Double speed) {
// @formatter:on

public static Builder builder() {
Expand Down Expand Up @@ -491,7 +491,7 @@ public static final class Builder {

private AudioResponseFormat responseFormat = AudioResponseFormat.MP3;

private Float speed;
private Double speed;

public Builder model(String model) {
this.model = model;
Expand All @@ -518,7 +518,7 @@ public Builder responseFormat(AudioResponseFormat responseFormat) {
return this;
}

public Builder speed(Float speed) {
public Builder speed(Double speed) {
this.speed = speed;
return this;
}
Expand Down

This file was deleted.

Loading